1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// Test double for a terminal handle.
///
/// Records whether `kill` was called, lets tests control the value reported by
/// `was_stopped_by_user`, and serves a canned output and a shared exit-status task
/// set up by its constructors.
struct FakeTerminalHandle {
    /// Flipped to `true` by `kill`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; set via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot sender taken and fired by `signal_exit`. For handles built with
    /// `new_never_exits`, its receiver gates the `wait_for_exit` task.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task resolving to the terminal's exit status.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `id`.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
136impl crate::TerminalHandle for FakeTerminalHandle {
137 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
138 Ok(self.id.clone())
139 }
140
141 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
142 Ok(self.output.clone())
143 }
144
145 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
146 Ok(self.wait_for_exit.clone())
147 }
148
149 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
150 self.killed.store(true, Ordering::SeqCst);
151 self.signal_exit();
152 Ok(())
153 }
154
155 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
156 Ok(self.stopped_by_user.load(Ordering::SeqCst))
157 }
158}
159
/// Test double for a subagent handle.
///
/// Reports a fixed session id and answers every `send` with the output of a
/// shared summary task.
struct FakeSubagentHandle {
    /// Session id returned by `id`.
    session_id: acp::SessionId,
    /// Shared task whose output becomes the reply to every `send`.
    wait_for_summary_task: Shared<Task<String>>,
}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
171 let task = self.wait_for_summary_task.clone();
172 cx.background_spawn(async move { Ok(task.await) })
173 }
174}
175
/// Thread environment that hands out pre-configured fake handles.
///
/// `create_terminal` and `create_subagent` panic if the corresponding handle was
/// not provided.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Handle shared (via `Rc` clone) with every `create_terminal` caller.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Handle shared (via `Rc` clone) with every `create_subagent` caller.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
181
182impl FakeThreadEnvironment {
183 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
184 Self {
185 terminal_handle: Some(terminal_handle.into()),
186 ..self
187 }
188 }
189}
190
191impl crate::ThreadEnvironment for FakeThreadEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 _cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = self
200 .terminal_handle
201 .clone()
202 .expect("Terminal handle not available on FakeThreadEnvironment");
203 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
204 }
205
206 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
207 Ok(self
208 .subagent_handle
209 .clone()
210 .expect("Subagent handle not available on FakeThreadEnvironment")
211 as Rc<dyn SubagentHandle>)
212 }
213}
214
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
///
/// Each `create_terminal` call builds a fresh never-exiting `FakeTerminalHandle` and
/// records it, so a test can inspect every handle afterwards through `handles()`.
struct MultiTerminalEnvironment {
    /// Every handle created so far, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
219
220impl MultiTerminalEnvironment {
221 fn new() -> Self {
222 Self {
223 handles: std::cell::RefCell::new(Vec::new()),
224 }
225 }
226
227 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
228 self.handles.borrow().clone()
229 }
230}
231
232impl crate::ThreadEnvironment for MultiTerminalEnvironment {
233 fn create_terminal(
234 &self,
235 _command: String,
236 _cwd: Option<std::path::PathBuf>,
237 _output_byte_limit: Option<u64>,
238 cx: &mut AsyncApp,
239 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
240 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
241 self.handles.borrow_mut().push(handle.clone());
242 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
243 }
244
245 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
246 unimplemented!()
247 }
248}
249
250fn always_allow_tools(cx: &mut TestAppContext) {
251 cx.update(|cx| {
252 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
253 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
254 agent_settings::AgentSettings::override_global(settings, cx);
255 });
256}
257
258#[gpui::test]
259async fn test_echo(cx: &mut TestAppContext) {
260 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
261 let fake_model = model.as_fake();
262
263 let events = thread
264 .update(cx, |thread, cx| {
265 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
266 })
267 .unwrap();
268 cx.run_until_parked();
269 fake_model.send_last_completion_stream_text_chunk("Hello");
270 fake_model
271 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
272 fake_model.end_last_completion_stream();
273
274 let events = events.collect().await;
275 thread.update(cx, |thread, _cx| {
276 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
277 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
278 });
279 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
280}
281
/// When a terminal command exceeds its `timeout_ms`, the terminal tool must kill
/// the underlying handle and still return the output captured so far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // This handle's exit task only resolves after `kill`, so the tool can only
    // finish through its timeout path.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    // NOTE(review): presumably allowed because the tool holds `Rc`-based
    // (non-`Send`) state such as `environment`; confirm before removing.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            }),
            event_stream,
            cx,
        )
    });

    // The first update should embed the terminal in the tool call's content.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking. Between polls, drain runnable work and wait
    // on a 1ms executor timer so the tool's timeout can elapse.
    // NOTE(review): the give-up deadline uses wall-clock `Instant`, not the
    // executor's clock; confirm that is intended.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
347
/// Without a `timeout_ms`, the terminal tool must not kill a still-running
/// terminal handle.
#[gpui::test]
// NOTE(review): ignored without a stated reason; consider `#[ignore = "..."]`
// once the reason is known.
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // This handle never exits on its own; only `kill` resolves its exit task.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    // NOTE(review): presumably allowed because the tool holds `Rc`-based
    // (non-`Send`) state; confirm before removing.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Keep the task alive (bound, not dropped) for the rest of the test.
    let _task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            }),
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool some time in which a (wrongly applied) timeout could fire.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
397
398#[gpui::test]
399async fn test_thinking(cx: &mut TestAppContext) {
400 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
401 let fake_model = model.as_fake();
402
403 let events = thread
404 .update(cx, |thread, cx| {
405 thread.send(
406 UserMessageId::new(),
407 [indoc! {"
408 Testing:
409
410 Generate a thinking step where you just think the word 'Think',
411 and have your final answer be 'Hello'
412 "}],
413 cx,
414 )
415 })
416 .unwrap();
417 cx.run_until_parked();
418 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
419 text: "Think".to_string(),
420 signature: None,
421 });
422 fake_model.send_last_completion_stream_text_chunk("Hello");
423 fake_model
424 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
425 fake_model.end_last_completion_stream();
426
427 let events = events.collect().await;
428 thread.update(cx, |thread, _cx| {
429 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
430 assert_eq!(
431 thread.last_message().unwrap().to_markdown(),
432 indoc! {"
433 <think>Think</think>
434 Hello
435 "}
436 )
437 });
438 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
439}
440
441#[gpui::test]
442async fn test_system_prompt(cx: &mut TestAppContext) {
443 let ThreadTest {
444 model,
445 thread,
446 project_context,
447 ..
448 } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 project_context.update(cx, |project_context, _cx| {
452 project_context.shell = "test-shell".into()
453 });
454 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
455 thread
456 .update(cx, |thread, cx| {
457 thread.send(UserMessageId::new(), ["abc"], cx)
458 })
459 .unwrap();
460 cx.run_until_parked();
461 let mut pending_completions = fake_model.pending_completions();
462 assert_eq!(
463 pending_completions.len(),
464 1,
465 "unexpected pending completions: {:?}",
466 pending_completions
467 );
468
469 let pending_completion = pending_completions.pop().unwrap();
470 assert_eq!(pending_completion.messages[0].role, Role::System);
471
472 let system_message = &pending_completion.messages[0];
473 let system_prompt = system_message.content[0].to_str().unwrap();
474 assert!(
475 system_prompt.contains("test-shell"),
476 "unexpected system message: {:?}",
477 system_message
478 );
479 assert!(
480 system_prompt.contains("## Fixing Diagnostics"),
481 "unexpected system message: {:?}",
482 system_message
483 );
484}
485
486#[gpui::test]
487async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
488 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
489 let fake_model = model.as_fake();
490
491 thread
492 .update(cx, |thread, cx| {
493 thread.send(UserMessageId::new(), ["abc"], cx)
494 })
495 .unwrap();
496 cx.run_until_parked();
497 let mut pending_completions = fake_model.pending_completions();
498 assert_eq!(
499 pending_completions.len(),
500 1,
501 "unexpected pending completions: {:?}",
502 pending_completions
503 );
504
505 let pending_completion = pending_completions.pop().unwrap();
506 assert_eq!(pending_completion.messages[0].role, Role::System);
507
508 let system_message = &pending_completion.messages[0];
509 let system_prompt = system_message.content[0].to_str().unwrap();
510 assert!(
511 !system_prompt.contains("## Tool Use"),
512 "unexpected system message: {:?}",
513 system_message
514 );
515 assert!(
516 !system_prompt.contains("## Fixing Diagnostics"),
517 "unexpected system message: {:?}",
518 system_message
519 );
520}
521
/// Checks the prompt-caching flags on outgoing requests. In each request only
/// the most recent message is marked `cache: true`, including when that message
/// is a tool result.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The assertions below skip `messages[0]`, which is the system message
    // (see `test_system_prompt`).
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (issued after the tool ran) ends with the tool result,
    // which should be the only cached message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
667
/// End-to-end check against a real model (Sonnet 4), run only with the `e2e`
/// feature. Both a fast tool (echo) and a delayed tool complete, and each turn
/// ends with `EndTurn`.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The final agent message should contain text mentioning "Ding"; on failure
    // the whole thread is printed as markdown.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
727
/// End-to-end check against a real model (Sonnet 4), run only with the `e2e`
/// feature, that tool-call input is streamed incrementally.
///
/// At least one pending `word_list` tool use must be observed whose input is
/// incomplete and lacks key `g`. Once a call is no longer pending, keys `a` and
/// `g` must both be present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to use the word_list tool and inspect each event as it streams in.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Still streaming: note that we saw input before key `g` arrived.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
779
/// Exercises the permission flow for `ToolRequiringPermission`.
///
/// The steps are: allow one call once, deny another, choose "always allow", and
/// finally make a call that must run without any prompt because of that choice.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two permission-gated tool calls in the same completion; each one should
    // produce its own authorization request.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries both results: the allowed call succeeds and
    // the denied one is reported as an error.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    // No permission response is sent here; the result must still be "Allowed".
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
918
919#[gpui::test]
920async fn test_tool_hallucination(cx: &mut TestAppContext) {
921 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
922 let fake_model = model.as_fake();
923
924 let mut events = thread
925 .update(cx, |thread, cx| {
926 thread.send(UserMessageId::new(), ["abc"], cx)
927 })
928 .unwrap();
929 cx.run_until_parked();
930 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
931 LanguageModelToolUse {
932 id: "tool_id_1".into(),
933 name: "nonexistent_tool".into(),
934 raw_input: "{}".into(),
935 input: json!({}),
936 is_input_complete: true,
937 thought_signature: None,
938 },
939 ));
940 fake_model.end_last_completion_stream();
941
942 let tool_call = expect_tool_call(&mut events).await;
943 assert_eq!(tool_call.title, "nonexistent_tool");
944 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
945 let update = expect_tool_call_update_fields(&mut events).await;
946 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
947}
948
949async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
950 let event = events
951 .next()
952 .await
953 .expect("no tool call authorization event received")
954 .unwrap();
955 match event {
956 ThreadEvent::ToolCall(tool_call) => tool_call,
957 event => {
958 panic!("Unexpected event {event:?}");
959 }
960 }
961}
962
963async fn expect_tool_call_update_fields(
964 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
965) -> acp::ToolCallUpdate {
966 let event = events
967 .next()
968 .await
969 .expect("no tool call authorization event received")
970 .unwrap();
971 match event {
972 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
973 event => {
974 panic!("Unexpected event {event:?}");
975 }
976 }
977}
978
979async fn next_tool_call_authorization(
980 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
981) -> ToolCallAuthorization {
982 loop {
983 let event = events
984 .next()
985 .await
986 .expect("no tool call authorization event received")
987 .unwrap();
988 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
989 let permission_kinds = tool_call_authorization
990 .options
991 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
992 .map(|option| option.kind);
993 let allow_once = tool_call_authorization
994 .options
995 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
996 .map(|option| option.kind);
997
998 assert_eq!(
999 permission_kinds,
1000 Some(acp::PermissionOptionKind::AllowAlways)
1001 );
1002 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1003 return tool_call_authorization;
1004 }
1005 }
1006}
1007
/// A multi-word terminal command yields three dropdown choices: tool-wide,
/// command-prefix (`cargo build`), and one-off.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1029
/// When the second token is a flag (`ls -la`), only the command name (`ls`) is
/// used for the command-pattern choice.
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1049
/// A single-word command (`whoami`) still gets its own command-pattern choice.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1069
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // For file edits, the pattern option is scoped to the file's directory (`src/`).
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(labels.contains(&expected));
    }
}
1087
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // For fetches, the pattern option is scoped to the URL's host (`docs.rs`).
    let context =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected));
    }
}
1105
#[test]
fn test_permission_options_without_pattern() {
    // This command yields no "... commands" pattern option, so only the
    // tool-wide and one-off choices remain.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 2);

    let labels: Vec<&str> = choices.iter().map(|c| c.allow.name.as_ref()).collect();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected));
    }
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1127
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization offers only a flat allow-once / reject-once pair.
    let context =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()]);
    let options = match context.build_permission_options() {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(options.len(), 2);

    // True when some option has exactly this id and kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1148
#[test]
fn test_permission_option_ids_for_terminal() {
    // Each dropdown choice pairs an allow option with a deny option; check that
    // the tool-wide, one-off, and pattern variants all carry the expected ids.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1188
1189#[gpui::test]
1190#[cfg_attr(not(feature = "e2e"), ignore)]
1191async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1192 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1193
1194 // Test concurrent tool calls with different delay times
1195 let events = thread
1196 .update(cx, |thread, cx| {
1197 thread.add_tool(DelayTool);
1198 thread.send(
1199 UserMessageId::new(),
1200 [
1201 "Call the delay tool twice in the same message.",
1202 "Once with 100ms. Once with 300ms.",
1203 "When both timers are complete, describe the outputs.",
1204 ],
1205 cx,
1206 )
1207 })
1208 .unwrap()
1209 .collect()
1210 .await;
1211
1212 let stop_reasons = stop_events(events);
1213 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1214
1215 thread.update(cx, |thread, _cx| {
1216 let last_message = thread.last_message().unwrap();
1217 let agent_message = last_message.as_agent_message().unwrap();
1218 let text = agent_message
1219 .content
1220 .iter()
1221 .filter_map(|content| {
1222 if let AgentMessageContent::Text(text) = content {
1223 Some(text.as_str())
1224 } else {
1225 None
1226 }
1227 })
1228 .collect::<String>();
1229
1230 assert!(text.contains("Ding"));
1231 });
1232}
1233
1234#[gpui::test]
1235async fn test_profiles(cx: &mut TestAppContext) {
1236 let ThreadTest {
1237 model, thread, fs, ..
1238 } = setup(cx, TestModel::Fake).await;
1239 let fake_model = model.as_fake();
1240
1241 thread.update(cx, |thread, _cx| {
1242 thread.add_tool(DelayTool);
1243 thread.add_tool(EchoTool);
1244 thread.add_tool(InfiniteTool);
1245 });
1246
1247 // Override profiles and wait for settings to be loaded.
1248 fs.insert_file(
1249 paths::settings_file(),
1250 json!({
1251 "agent": {
1252 "profiles": {
1253 "test-1": {
1254 "name": "Test Profile 1",
1255 "tools": {
1256 EchoTool::NAME: true,
1257 DelayTool::NAME: true,
1258 }
1259 },
1260 "test-2": {
1261 "name": "Test Profile 2",
1262 "tools": {
1263 InfiniteTool::NAME: true,
1264 }
1265 }
1266 }
1267 }
1268 })
1269 .to_string()
1270 .into_bytes(),
1271 )
1272 .await;
1273 cx.run_until_parked();
1274
1275 // Test that test-1 profile (default) has echo and delay tools
1276 thread
1277 .update(cx, |thread, cx| {
1278 thread.set_profile(AgentProfileId("test-1".into()), cx);
1279 thread.send(UserMessageId::new(), ["test"], cx)
1280 })
1281 .unwrap();
1282 cx.run_until_parked();
1283
1284 let mut pending_completions = fake_model.pending_completions();
1285 assert_eq!(pending_completions.len(), 1);
1286 let completion = pending_completions.pop().unwrap();
1287 let tool_names: Vec<String> = completion
1288 .tools
1289 .iter()
1290 .map(|tool| tool.name.clone())
1291 .collect();
1292 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1293 fake_model.end_last_completion_stream();
1294
1295 // Switch to test-2 profile, and verify that it has only the infinite tool.
1296 thread
1297 .update(cx, |thread, cx| {
1298 thread.set_profile(AgentProfileId("test-2".into()), cx);
1299 thread.send(UserMessageId::new(), ["test2"], cx)
1300 })
1301 .unwrap();
1302 cx.run_until_parked();
1303 let mut pending_completions = fake_model.pending_completions();
1304 assert_eq!(pending_completions.len(), 1);
1305 let completion = pending_completions.pop().unwrap();
1306 let tool_names: Vec<String> = completion
1307 .tools
1308 .iter()
1309 .map(|tool| tool.name.clone())
1310 .collect();
1311 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1312}
1313
1314#[gpui::test]
1315async fn test_mcp_tools(cx: &mut TestAppContext) {
1316 let ThreadTest {
1317 model,
1318 thread,
1319 context_server_store,
1320 fs,
1321 ..
1322 } = setup(cx, TestModel::Fake).await;
1323 let fake_model = model.as_fake();
1324
1325 // Override profiles and wait for settings to be loaded.
1326 fs.insert_file(
1327 paths::settings_file(),
1328 json!({
1329 "agent": {
1330 "tool_permissions": { "default": "allow" },
1331 "profiles": {
1332 "test": {
1333 "name": "Test Profile",
1334 "enable_all_context_servers": true,
1335 "tools": {
1336 EchoTool::NAME: true,
1337 }
1338 },
1339 }
1340 }
1341 })
1342 .to_string()
1343 .into_bytes(),
1344 )
1345 .await;
1346 cx.run_until_parked();
1347 thread.update(cx, |thread, cx| {
1348 thread.set_profile(AgentProfileId("test".into()), cx)
1349 });
1350
1351 let mut mcp_tool_calls = setup_context_server(
1352 "test_server",
1353 vec![context_server::types::Tool {
1354 name: "echo".into(),
1355 description: None,
1356 input_schema: serde_json::to_value(EchoTool::input_schema(
1357 LanguageModelToolSchemaFormat::JsonSchema,
1358 ))
1359 .unwrap(),
1360 output_schema: None,
1361 annotations: None,
1362 }],
1363 &context_server_store,
1364 cx,
1365 );
1366
1367 let events = thread.update(cx, |thread, cx| {
1368 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1369 });
1370 cx.run_until_parked();
1371
1372 // Simulate the model calling the MCP tool.
1373 let completion = fake_model.pending_completions().pop().unwrap();
1374 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1375 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1376 LanguageModelToolUse {
1377 id: "tool_1".into(),
1378 name: "echo".into(),
1379 raw_input: json!({"text": "test"}).to_string(),
1380 input: json!({"text": "test"}),
1381 is_input_complete: true,
1382 thought_signature: None,
1383 },
1384 ));
1385 fake_model.end_last_completion_stream();
1386 cx.run_until_parked();
1387
1388 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1389 assert_eq!(tool_call_params.name, "echo");
1390 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1391 tool_call_response
1392 .send(context_server::types::CallToolResponse {
1393 content: vec![context_server::types::ToolResponseContent::Text {
1394 text: "test".into(),
1395 }],
1396 is_error: None,
1397 meta: None,
1398 structured_content: None,
1399 })
1400 .unwrap();
1401 cx.run_until_parked();
1402
1403 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1404 fake_model.send_last_completion_stream_text_chunk("Done!");
1405 fake_model.end_last_completion_stream();
1406 events.collect::<Vec<_>>().await;
1407
1408 // Send again after adding the echo tool, ensuring the name collision is resolved.
1409 let events = thread.update(cx, |thread, cx| {
1410 thread.add_tool(EchoTool);
1411 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1412 });
1413 cx.run_until_parked();
1414 let completion = fake_model.pending_completions().pop().unwrap();
1415 assert_eq!(
1416 tool_names_for_completion(&completion),
1417 vec!["echo", "test_server_echo"]
1418 );
1419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1420 LanguageModelToolUse {
1421 id: "tool_2".into(),
1422 name: "test_server_echo".into(),
1423 raw_input: json!({"text": "mcp"}).to_string(),
1424 input: json!({"text": "mcp"}),
1425 is_input_complete: true,
1426 thought_signature: None,
1427 },
1428 ));
1429 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1430 LanguageModelToolUse {
1431 id: "tool_3".into(),
1432 name: "echo".into(),
1433 raw_input: json!({"text": "native"}).to_string(),
1434 input: json!({"text": "native"}),
1435 is_input_complete: true,
1436 thought_signature: None,
1437 },
1438 ));
1439 fake_model.end_last_completion_stream();
1440 cx.run_until_parked();
1441
1442 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1443 assert_eq!(tool_call_params.name, "echo");
1444 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1445 tool_call_response
1446 .send(context_server::types::CallToolResponse {
1447 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1448 is_error: None,
1449 meta: None,
1450 structured_content: None,
1451 })
1452 .unwrap();
1453 cx.run_until_parked();
1454
1455 // Ensure the tool results were inserted with the correct names.
1456 let completion = fake_model.pending_completions().pop().unwrap();
1457 assert_eq!(
1458 completion.messages.last().unwrap().content,
1459 vec![
1460 MessageContent::ToolResult(LanguageModelToolResult {
1461 tool_use_id: "tool_3".into(),
1462 tool_name: "echo".into(),
1463 is_error: false,
1464 content: "native".into(),
1465 output: Some("native".into()),
1466 },),
1467 MessageContent::ToolResult(LanguageModelToolResult {
1468 tool_use_id: "tool_2".into(),
1469 tool_name: "test_server_echo".into(),
1470 is_error: false,
1471 content: "mcp".into(),
1472 output: Some("mcp".into()),
1473 },),
1474 ]
1475 );
1476 fake_model.end_last_completion_stream();
1477 events.collect::<Vec<_>>().await;
1478}
1479
/// Checks that a completed MCP tool call's saved output is still replayed after
/// the originating context server has been stopped.
///
/// This mirrors reopening a saved thread before its MCP server has reconnected.
/// Replay must emit a `ToolCallUpdate` for the call that carries the stored
/// `raw_output` and a `Completed` status.
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    // NOTE(review): `test_mcp_tools` grants permission via
    // `"tool_permissions": { "default": "allow" }` instead; confirm that
    // `always_allow_tool_actions` is still honored here.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown
    // output of the whole thread. Searching the whole thread matters because the
    // tool result sits in the first assistant message, not the last one (which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    // The return value of `stop_server` is intentionally discarded.
    context_server_store.update(cx, |store, cx| {
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    // Drain the replay stream. Remember the `ToolCall` event for `tool_1` and
    // the last `UpdateFields` event for it that carries a `raw_output`.
    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1661
/// Checks the tool names offered to the model when several MCP servers expose
/// tools that collide with native tools or with each other. It also covers MCP
/// tools whose names are close to, or longer than, `MAX_TOOL_NAME_LENGTH`.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" also defines the long `a…`/`b…` names that server "zzz" repeats below.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz" duplicates the `a…`/`b…` names from "yyy" and adds a `c…`
    // name longer than MAX_TOOL_NAME_LENGTH.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // What the expected list below shows:
    // - Native tools keep their own names (`echo`, `delay`, ...).
    // - Non-conflicting MCP tools keep theirs (`unique_tool_1`, `unique_tool_2`).
    // - MCP `echo` tools get a server-name prefix (`xxx_echo`, `yyy_echo`,
    //   `azure_dev_ops_echo`).
    // - The `a…` name, shared by "yyy" and "zzz", appears only with shortened
    //   prefixes (`y_`, `z_`).
    // - The `b…` name, also shared, appears once and unprefixed; presumably
    //   prefixing would exceed the length limit. Confirm in the name-resolution code.
    // - The over-long `c…` name appears shortened; presumably it is truncated to
    //   MAX_TOOL_NAME_LENGTH. Confirm.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1845
/// End-to-end test. It only runs with the `e2e` feature and uses `TestModel::Sonnet4`.
///
/// Cancelling a send while a tool (the infinite tool) is still running must
/// close the event stream with a `Cancelled` stop. The thread must still accept
/// and answer a new message afterward.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls are expected in this exact order. `expected_tools.remove(0)`
    // panics if more tool calls arrive than are listed here.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1931
/// Cancels a thread while the terminal tool is running a command that never exits.
///
/// Checks that the terminal handle is killed and the turn stops with `Cancelled`.
/// The stored tool result must include the terminal's partial output plus a
/// "user stopped" note, rather than a generic cancellation message. The thread
/// must still work afterward.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Keep a clone of the fake terminal handle so we can later check it was killed.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running.
    // The cancel task is detached, not awaited; the event collection below
    // drives the executor so cancellation can complete.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Take the first tool use in the agent message and look up its stored result.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2028
/// Checks that a tool which observes cancellation makes `thread.cancel` finish
/// promptly, and that the turn then ends with `Cancelled`.
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a flag shared with the tool; the tool sets it when it
    // observes cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported.
    // Each attempt drains already-ready events without blocking (`now_or_never`),
    // then sleeps 10ms on the background executor.
    // NOTE(review): the attempt budget scales with `num_cpus()`; the reason is
    // not visible here. Confirm this is intended.
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // This `break` only leaves the inner `while let`; the outer loop is
        // exited by the `if tool_started` check below.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2123
2124/// Helper to verify thread can recover after cancellation by sending a simple message.
2125async fn verify_thread_recovery(
2126 thread: &Entity<Thread>,
2127 fake_model: &FakeLanguageModel,
2128 cx: &mut TestAppContext,
2129) {
2130 let events = thread
2131 .update(cx, |thread, cx| {
2132 thread.send(
2133 UserMessageId::new(),
2134 ["Testing: reply with 'Hello' then stop."],
2135 cx,
2136 )
2137 })
2138 .unwrap();
2139 cx.run_until_parked();
2140 fake_model.send_last_completion_stream_text_chunk("Hello");
2141 fake_model
2142 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2143 fake_model.end_last_completion_stream();
2144
2145 let events = events.collect::<Vec<_>>().await;
2146 thread.update(cx, |thread, _cx| {
2147 let message = thread.last_message().unwrap();
2148 let agent_message = message.as_agent_message().unwrap();
2149 assert_eq!(
2150 agent_message.content,
2151 vec![AgentMessageContent::Text("Hello".to_string())]
2152 );
2153 });
2154 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2155}
2156
2157/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2158async fn wait_for_terminal_tool_started(
2159 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2160 cx: &mut TestAppContext,
2161) {
2162 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2163 for _ in 0..deadline {
2164 cx.run_until_parked();
2165
2166 while let Some(Some(event)) = events.next().now_or_never() {
2167 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2168 update,
2169 ))) = &event
2170 {
2171 if update.fields.content.as_ref().is_some_and(|content| {
2172 content
2173 .iter()
2174 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2175 }) {
2176 return;
2177 }
2178 }
2179 }
2180
2181 cx.background_executor
2182 .timer(Duration::from_millis(10))
2183 .await;
2184 }
2185 panic!("terminal tool did not start within the expected time");
2186}
2187
2188/// Collects events until a Stop event is received, driving the executor to completion.
2189async fn collect_events_until_stop(
2190 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2191 cx: &mut TestAppContext,
2192) -> Vec<Result<ThreadEvent>> {
2193 let mut collected = Vec::new();
2194 let deadline = cx.executor().num_cpus() * 200;
2195
2196 for _ in 0..deadline {
2197 cx.executor().advance_clock(Duration::from_millis(10));
2198 cx.run_until_parked();
2199
2200 while let Some(Some(event)) = events.next().now_or_never() {
2201 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2202 collected.push(event);
2203 if is_stop {
2204 return collected;
2205 }
2206 }
2207 }
2208 panic!(
2209 "did not receive Stop event within the expected time; collected {} events",
2210 collected.len()
2211 );
2212}
2213
2214#[gpui::test]
2215async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2216 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2217 always_allow_tools(cx);
2218 let fake_model = model.as_fake();
2219
2220 let environment = Rc::new(cx.update(|cx| {
2221 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2222 }));
2223 let handle = environment.terminal_handle.clone().unwrap();
2224
2225 let message_id = UserMessageId::new();
2226 let mut events = thread
2227 .update(cx, |thread, cx| {
2228 thread.add_tool(crate::TerminalTool::new(
2229 thread.project().clone(),
2230 environment,
2231 ));
2232 thread.send(message_id.clone(), ["run a command"], cx)
2233 })
2234 .unwrap();
2235
2236 cx.run_until_parked();
2237
2238 // Simulate the model calling the terminal tool
2239 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2240 LanguageModelToolUse {
2241 id: "terminal_tool_1".into(),
2242 name: TerminalTool::NAME.into(),
2243 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2244 input: json!({"command": "sleep 1000", "cd": "."}),
2245 is_input_complete: true,
2246 thought_signature: None,
2247 },
2248 ));
2249 fake_model.end_last_completion_stream();
2250
2251 // Wait for the terminal tool to start running
2252 wait_for_terminal_tool_started(&mut events, cx).await;
2253
2254 // Truncate the thread while the terminal is running
2255 thread
2256 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2257 .unwrap();
2258
2259 // Drive the executor to let cancellation complete
2260 let _ = collect_events_until_stop(&mut events, cx).await;
2261
2262 // Verify the terminal was killed
2263 assert!(
2264 handle.was_killed(),
2265 "expected terminal handle to be killed on truncate"
2266 );
2267
2268 // Verify the thread is empty after truncation
2269 thread.update(cx, |thread, _cx| {
2270 assert_eq!(
2271 thread.to_markdown(),
2272 "",
2273 "expected thread to be empty after truncating the only message"
2274 );
2275 });
2276
2277 // Verify we can send a new message after truncation
2278 verify_thread_recovery(&thread, &fake_model, cx).await;
2279}
2280
2281#[gpui::test]
2282async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2283 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2284 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2285 always_allow_tools(cx);
2286 let fake_model = model.as_fake();
2287
2288 let environment = Rc::new(MultiTerminalEnvironment::new());
2289
2290 let mut events = thread
2291 .update(cx, |thread, cx| {
2292 thread.add_tool(crate::TerminalTool::new(
2293 thread.project().clone(),
2294 environment.clone(),
2295 ));
2296 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2297 })
2298 .unwrap();
2299
2300 cx.run_until_parked();
2301
2302 // Simulate the model calling two terminal tools
2303 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2304 LanguageModelToolUse {
2305 id: "terminal_tool_1".into(),
2306 name: TerminalTool::NAME.into(),
2307 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2308 input: json!({"command": "sleep 1000", "cd": "."}),
2309 is_input_complete: true,
2310 thought_signature: None,
2311 },
2312 ));
2313 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2314 LanguageModelToolUse {
2315 id: "terminal_tool_2".into(),
2316 name: TerminalTool::NAME.into(),
2317 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2318 input: json!({"command": "sleep 2000", "cd": "."}),
2319 is_input_complete: true,
2320 thought_signature: None,
2321 },
2322 ));
2323 fake_model.end_last_completion_stream();
2324
2325 // Wait for both terminal tools to start by counting terminal content updates
2326 let mut terminals_started = 0;
2327 let deadline = cx.executor().num_cpus() * 100;
2328 for _ in 0..deadline {
2329 cx.run_until_parked();
2330
2331 while let Some(Some(event)) = events.next().now_or_never() {
2332 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2333 update,
2334 ))) = &event
2335 {
2336 if update.fields.content.as_ref().is_some_and(|content| {
2337 content
2338 .iter()
2339 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2340 }) {
2341 terminals_started += 1;
2342 if terminals_started >= 2 {
2343 break;
2344 }
2345 }
2346 }
2347 }
2348 if terminals_started >= 2 {
2349 break;
2350 }
2351
2352 cx.background_executor
2353 .timer(Duration::from_millis(10))
2354 .await;
2355 }
2356 assert!(
2357 terminals_started >= 2,
2358 "expected 2 terminal tools to start, got {terminals_started}"
2359 );
2360
2361 // Cancel the thread while both terminals are running
2362 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2363
2364 // Collect remaining events
2365 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2366
2367 // Verify both terminal handles were killed
2368 let handles = environment.handles();
2369 assert_eq!(
2370 handles.len(),
2371 2,
2372 "expected 2 terminal handles to be created"
2373 );
2374 assert!(
2375 handles[0].was_killed(),
2376 "expected first terminal handle to be killed on cancellation"
2377 );
2378 assert!(
2379 handles[1].was_killed(),
2380 "expected second terminal handle to be killed on cancellation"
2381 );
2382
2383 // Verify we got a cancellation stop event
2384 assert_eq!(
2385 stop_events(remaining_events),
2386 vec![acp::StopReason::Cancelled],
2387 );
2388}
2389
2390#[gpui::test]
2391async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2392 // Tests that clicking the stop button on the terminal card (as opposed to the main
2393 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2394 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2395 always_allow_tools(cx);
2396 let fake_model = model.as_fake();
2397
2398 let environment = Rc::new(cx.update(|cx| {
2399 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2400 }));
2401 let handle = environment.terminal_handle.clone().unwrap();
2402
2403 let mut events = thread
2404 .update(cx, |thread, cx| {
2405 thread.add_tool(crate::TerminalTool::new(
2406 thread.project().clone(),
2407 environment,
2408 ));
2409 thread.send(UserMessageId::new(), ["run a command"], cx)
2410 })
2411 .unwrap();
2412
2413 cx.run_until_parked();
2414
2415 // Simulate the model calling the terminal tool
2416 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2417 LanguageModelToolUse {
2418 id: "terminal_tool_1".into(),
2419 name: TerminalTool::NAME.into(),
2420 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2421 input: json!({"command": "sleep 1000", "cd": "."}),
2422 is_input_complete: true,
2423 thought_signature: None,
2424 },
2425 ));
2426 fake_model.end_last_completion_stream();
2427
2428 // Wait for the terminal tool to start running
2429 wait_for_terminal_tool_started(&mut events, cx).await;
2430
2431 // Simulate user clicking stop on the terminal card itself.
2432 // This sets the flag and signals exit (simulating what the real UI would do).
2433 handle.set_stopped_by_user(true);
2434 handle.killed.store(true, Ordering::SeqCst);
2435 handle.signal_exit();
2436
2437 // Wait for the tool to complete
2438 cx.run_until_parked();
2439
2440 // The thread continues after tool completion - simulate the model ending its turn
2441 fake_model
2442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2443 fake_model.end_last_completion_stream();
2444
2445 // Collect remaining events
2446 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2447
2448 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2449 assert_eq!(
2450 stop_events(remaining_events),
2451 vec![acp::StopReason::EndTurn],
2452 );
2453
2454 // Verify the tool result indicates user stopped
2455 thread.update(cx, |thread, _cx| {
2456 let message = thread.last_message().unwrap();
2457 let agent_message = message.as_agent_message().unwrap();
2458
2459 let tool_use = agent_message
2460 .content
2461 .iter()
2462 .find_map(|content| match content {
2463 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2464 _ => None,
2465 })
2466 .expect("expected tool use in agent message");
2467
2468 let tool_result = agent_message
2469 .tool_results
2470 .get(&tool_use.id)
2471 .expect("expected tool result");
2472
2473 let result_text = match &tool_result.content {
2474 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2475 _ => panic!("expected text content in tool result"),
2476 };
2477
2478 assert!(
2479 result_text.contains("The user stopped this command"),
2480 "expected tool result to indicate user stopped, got: {result_text}"
2481 );
2482 });
2483}
2484
2485#[gpui::test]
2486async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2487 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2488 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2489 always_allow_tools(cx);
2490 let fake_model = model.as_fake();
2491
2492 let environment = Rc::new(cx.update(|cx| {
2493 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2494 }));
2495 let handle = environment.terminal_handle.clone().unwrap();
2496
2497 let mut events = thread
2498 .update(cx, |thread, cx| {
2499 thread.add_tool(crate::TerminalTool::new(
2500 thread.project().clone(),
2501 environment,
2502 ));
2503 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2504 })
2505 .unwrap();
2506
2507 cx.run_until_parked();
2508
2509 // Simulate the model calling the terminal tool with a short timeout
2510 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2511 LanguageModelToolUse {
2512 id: "terminal_tool_1".into(),
2513 name: TerminalTool::NAME.into(),
2514 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2515 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2516 is_input_complete: true,
2517 thought_signature: None,
2518 },
2519 ));
2520 fake_model.end_last_completion_stream();
2521
2522 // Wait for the terminal tool to start running
2523 wait_for_terminal_tool_started(&mut events, cx).await;
2524
2525 // Advance clock past the timeout
2526 cx.executor().advance_clock(Duration::from_millis(200));
2527 cx.run_until_parked();
2528
2529 // The thread continues after tool completion - simulate the model ending its turn
2530 fake_model
2531 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2532 fake_model.end_last_completion_stream();
2533
2534 // Collect remaining events
2535 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2536
2537 // Verify the terminal was killed due to timeout
2538 assert!(
2539 handle.was_killed(),
2540 "expected terminal handle to be killed on timeout"
2541 );
2542
2543 // Verify we got an EndTurn (the tool completed, just with timeout)
2544 assert_eq!(
2545 stop_events(remaining_events),
2546 vec![acp::StopReason::EndTurn],
2547 );
2548
2549 // Verify the tool result indicates timeout, not user stopped
2550 thread.update(cx, |thread, _cx| {
2551 let message = thread.last_message().unwrap();
2552 let agent_message = message.as_agent_message().unwrap();
2553
2554 let tool_use = agent_message
2555 .content
2556 .iter()
2557 .find_map(|content| match content {
2558 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2559 _ => None,
2560 })
2561 .expect("expected tool use in agent message");
2562
2563 let tool_result = agent_message
2564 .tool_results
2565 .get(&tool_use.id)
2566 .expect("expected tool result");
2567
2568 let result_text = match &tool_result.content {
2569 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2570 _ => panic!("expected text content in tool result"),
2571 };
2572
2573 assert!(
2574 result_text.contains("timed out"),
2575 "expected tool result to indicate timeout, got: {result_text}"
2576 );
2577 assert!(
2578 !result_text.contains("The user stopped"),
2579 "tool result should not mention user stopped when it timed out, got: {result_text}"
2580 );
2581 });
2582}
2583
2584#[gpui::test]
2585async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2586 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2587 let fake_model = model.as_fake();
2588
2589 let events_1 = thread
2590 .update(cx, |thread, cx| {
2591 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2592 })
2593 .unwrap();
2594 cx.run_until_parked();
2595 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2596 cx.run_until_parked();
2597
2598 let events_2 = thread
2599 .update(cx, |thread, cx| {
2600 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2601 })
2602 .unwrap();
2603 cx.run_until_parked();
2604 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2605 fake_model
2606 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2607 fake_model.end_last_completion_stream();
2608
2609 let events_1 = events_1.collect::<Vec<_>>().await;
2610 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2611 let events_2 = events_2.collect::<Vec<_>>().await;
2612 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2613}
2614
2615#[gpui::test]
2616async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2617 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2618 let fake_model = model.as_fake();
2619
2620 let events_1 = thread
2621 .update(cx, |thread, cx| {
2622 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2623 })
2624 .unwrap();
2625 cx.run_until_parked();
2626 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2627 fake_model
2628 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2629 fake_model.end_last_completion_stream();
2630 let events_1 = events_1.collect::<Vec<_>>().await;
2631
2632 let events_2 = thread
2633 .update(cx, |thread, cx| {
2634 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2635 })
2636 .unwrap();
2637 cx.run_until_parked();
2638 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2639 fake_model
2640 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2641 fake_model.end_last_completion_stream();
2642 let events_2 = events_2.collect::<Vec<_>>().await;
2643
2644 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2645 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2646}
2647
2648#[gpui::test]
2649async fn test_refusal(cx: &mut TestAppContext) {
2650 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2651 let fake_model = model.as_fake();
2652
2653 let events = thread
2654 .update(cx, |thread, cx| {
2655 thread.send(UserMessageId::new(), ["Hello"], cx)
2656 })
2657 .unwrap();
2658 cx.run_until_parked();
2659 thread.read_with(cx, |thread, _| {
2660 assert_eq!(
2661 thread.to_markdown(),
2662 indoc! {"
2663 ## User
2664
2665 Hello
2666 "}
2667 );
2668 });
2669
2670 fake_model.send_last_completion_stream_text_chunk("Hey!");
2671 cx.run_until_parked();
2672 thread.read_with(cx, |thread, _| {
2673 assert_eq!(
2674 thread.to_markdown(),
2675 indoc! {"
2676 ## User
2677
2678 Hello
2679
2680 ## Assistant
2681
2682 Hey!
2683 "}
2684 );
2685 });
2686
2687 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2688 fake_model
2689 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2690 let events = events.collect::<Vec<_>>().await;
2691 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2692 thread.read_with(cx, |thread, _| {
2693 assert_eq!(thread.to_markdown(), "");
2694 });
2695}
2696
2697#[gpui::test]
2698async fn test_truncate_first_message(cx: &mut TestAppContext) {
2699 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2700 let fake_model = model.as_fake();
2701
2702 let message_id = UserMessageId::new();
2703 thread
2704 .update(cx, |thread, cx| {
2705 thread.send(message_id.clone(), ["Hello"], cx)
2706 })
2707 .unwrap();
2708 cx.run_until_parked();
2709 thread.read_with(cx, |thread, _| {
2710 assert_eq!(
2711 thread.to_markdown(),
2712 indoc! {"
2713 ## User
2714
2715 Hello
2716 "}
2717 );
2718 assert_eq!(thread.latest_token_usage(), None);
2719 });
2720
2721 fake_model.send_last_completion_stream_text_chunk("Hey!");
2722 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2723 language_model::TokenUsage {
2724 input_tokens: 32_000,
2725 output_tokens: 16_000,
2726 cache_creation_input_tokens: 0,
2727 cache_read_input_tokens: 0,
2728 },
2729 ));
2730 cx.run_until_parked();
2731 thread.read_with(cx, |thread, _| {
2732 assert_eq!(
2733 thread.to_markdown(),
2734 indoc! {"
2735 ## User
2736
2737 Hello
2738
2739 ## Assistant
2740
2741 Hey!
2742 "}
2743 );
2744 assert_eq!(
2745 thread.latest_token_usage(),
2746 Some(acp_thread::TokenUsage {
2747 used_tokens: 32_000 + 16_000,
2748 max_tokens: 1_000_000,
2749 max_output_tokens: None,
2750 input_tokens: 32_000,
2751 output_tokens: 16_000,
2752 })
2753 );
2754 });
2755
2756 thread
2757 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2758 .unwrap();
2759 cx.run_until_parked();
2760 thread.read_with(cx, |thread, _| {
2761 assert_eq!(thread.to_markdown(), "");
2762 assert_eq!(thread.latest_token_usage(), None);
2763 });
2764
2765 // Ensure we can still send a new message after truncation.
2766 thread
2767 .update(cx, |thread, cx| {
2768 thread.send(UserMessageId::new(), ["Hi"], cx)
2769 })
2770 .unwrap();
2771 thread.update(cx, |thread, _cx| {
2772 assert_eq!(
2773 thread.to_markdown(),
2774 indoc! {"
2775 ## User
2776
2777 Hi
2778 "}
2779 );
2780 });
2781 cx.run_until_parked();
2782 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2783 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2784 language_model::TokenUsage {
2785 input_tokens: 40_000,
2786 output_tokens: 20_000,
2787 cache_creation_input_tokens: 0,
2788 cache_read_input_tokens: 0,
2789 },
2790 ));
2791 cx.run_until_parked();
2792 thread.read_with(cx, |thread, _| {
2793 assert_eq!(
2794 thread.to_markdown(),
2795 indoc! {"
2796 ## User
2797
2798 Hi
2799
2800 ## Assistant
2801
2802 Ahoy!
2803 "}
2804 );
2805
2806 assert_eq!(
2807 thread.latest_token_usage(),
2808 Some(acp_thread::TokenUsage {
2809 used_tokens: 40_000 + 20_000,
2810 max_tokens: 1_000_000,
2811 max_output_tokens: None,
2812 input_tokens: 40_000,
2813 output_tokens: 20_000,
2814 })
2815 );
2816 });
2817}
2818
2819#[gpui::test]
2820async fn test_truncate_second_message(cx: &mut TestAppContext) {
2821 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2822 let fake_model = model.as_fake();
2823
2824 thread
2825 .update(cx, |thread, cx| {
2826 thread.send(UserMessageId::new(), ["Message 1"], cx)
2827 })
2828 .unwrap();
2829 cx.run_until_parked();
2830 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2831 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2832 language_model::TokenUsage {
2833 input_tokens: 32_000,
2834 output_tokens: 16_000,
2835 cache_creation_input_tokens: 0,
2836 cache_read_input_tokens: 0,
2837 },
2838 ));
2839 fake_model.end_last_completion_stream();
2840 cx.run_until_parked();
2841
2842 let assert_first_message_state = |cx: &mut TestAppContext| {
2843 thread.clone().read_with(cx, |thread, _| {
2844 assert_eq!(
2845 thread.to_markdown(),
2846 indoc! {"
2847 ## User
2848
2849 Message 1
2850
2851 ## Assistant
2852
2853 Message 1 response
2854 "}
2855 );
2856
2857 assert_eq!(
2858 thread.latest_token_usage(),
2859 Some(acp_thread::TokenUsage {
2860 used_tokens: 32_000 + 16_000,
2861 max_tokens: 1_000_000,
2862 max_output_tokens: None,
2863 input_tokens: 32_000,
2864 output_tokens: 16_000,
2865 })
2866 );
2867 });
2868 };
2869
2870 assert_first_message_state(cx);
2871
2872 let second_message_id = UserMessageId::new();
2873 thread
2874 .update(cx, |thread, cx| {
2875 thread.send(second_message_id.clone(), ["Message 2"], cx)
2876 })
2877 .unwrap();
2878 cx.run_until_parked();
2879
2880 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2881 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2882 language_model::TokenUsage {
2883 input_tokens: 40_000,
2884 output_tokens: 20_000,
2885 cache_creation_input_tokens: 0,
2886 cache_read_input_tokens: 0,
2887 },
2888 ));
2889 fake_model.end_last_completion_stream();
2890 cx.run_until_parked();
2891
2892 thread.read_with(cx, |thread, _| {
2893 assert_eq!(
2894 thread.to_markdown(),
2895 indoc! {"
2896 ## User
2897
2898 Message 1
2899
2900 ## Assistant
2901
2902 Message 1 response
2903
2904 ## User
2905
2906 Message 2
2907
2908 ## Assistant
2909
2910 Message 2 response
2911 "}
2912 );
2913
2914 assert_eq!(
2915 thread.latest_token_usage(),
2916 Some(acp_thread::TokenUsage {
2917 used_tokens: 40_000 + 20_000,
2918 max_tokens: 1_000_000,
2919 max_output_tokens: None,
2920 input_tokens: 40_000,
2921 output_tokens: 20_000,
2922 })
2923 );
2924 });
2925
2926 thread
2927 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2928 .unwrap();
2929 cx.run_until_parked();
2930
2931 assert_first_message_state(cx);
2932}
2933
2934#[gpui::test]
2935async fn test_title_generation(cx: &mut TestAppContext) {
2936 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2937 let fake_model = model.as_fake();
2938
2939 let summary_model = Arc::new(FakeLanguageModel::default());
2940 thread.update(cx, |thread, cx| {
2941 thread.set_summarization_model(Some(summary_model.clone()), cx)
2942 });
2943
2944 let send = thread
2945 .update(cx, |thread, cx| {
2946 thread.send(UserMessageId::new(), ["Hello"], cx)
2947 })
2948 .unwrap();
2949 cx.run_until_parked();
2950
2951 fake_model.send_last_completion_stream_text_chunk("Hey!");
2952 fake_model.end_last_completion_stream();
2953 cx.run_until_parked();
2954 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2955
2956 // Ensure the summary model has been invoked to generate a title.
2957 summary_model.send_last_completion_stream_text_chunk("Hello ");
2958 summary_model.send_last_completion_stream_text_chunk("world\nG");
2959 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2960 summary_model.end_last_completion_stream();
2961 send.collect::<Vec<_>>().await;
2962 cx.run_until_parked();
2963 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2964
2965 // Send another message, ensuring no title is generated this time.
2966 let send = thread
2967 .update(cx, |thread, cx| {
2968 thread.send(UserMessageId::new(), ["Hello again"], cx)
2969 })
2970 .unwrap();
2971 cx.run_until_parked();
2972 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2973 fake_model.end_last_completion_stream();
2974 cx.run_until_parked();
2975 assert_eq!(summary_model.pending_completions(), Vec::new());
2976 send.collect::<Vec<_>>().await;
2977 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2978}
2979
2980#[gpui::test]
2981async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2982 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2983 let fake_model = model.as_fake();
2984
2985 let _events = thread
2986 .update(cx, |thread, cx| {
2987 thread.add_tool(ToolRequiringPermission);
2988 thread.add_tool(EchoTool);
2989 thread.send(UserMessageId::new(), ["Hey!"], cx)
2990 })
2991 .unwrap();
2992 cx.run_until_parked();
2993
2994 let permission_tool_use = LanguageModelToolUse {
2995 id: "tool_id_1".into(),
2996 name: ToolRequiringPermission::NAME.into(),
2997 raw_input: "{}".into(),
2998 input: json!({}),
2999 is_input_complete: true,
3000 thought_signature: None,
3001 };
3002 let echo_tool_use = LanguageModelToolUse {
3003 id: "tool_id_2".into(),
3004 name: EchoTool::NAME.into(),
3005 raw_input: json!({"text": "test"}).to_string(),
3006 input: json!({"text": "test"}),
3007 is_input_complete: true,
3008 thought_signature: None,
3009 };
3010 fake_model.send_last_completion_stream_text_chunk("Hi!");
3011 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3012 permission_tool_use,
3013 ));
3014 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3015 echo_tool_use.clone(),
3016 ));
3017 fake_model.end_last_completion_stream();
3018 cx.run_until_parked();
3019
3020 // Ensure pending tools are skipped when building a request.
3021 let request = thread
3022 .read_with(cx, |thread, cx| {
3023 thread.build_completion_request(CompletionIntent::EditFile, cx)
3024 })
3025 .unwrap();
3026 assert_eq!(
3027 request.messages[1..],
3028 vec![
3029 LanguageModelRequestMessage {
3030 role: Role::User,
3031 content: vec!["Hey!".into()],
3032 cache: true,
3033 reasoning_details: None,
3034 },
3035 LanguageModelRequestMessage {
3036 role: Role::Assistant,
3037 content: vec![
3038 MessageContent::Text("Hi!".into()),
3039 MessageContent::ToolUse(echo_tool_use.clone())
3040 ],
3041 cache: false,
3042 reasoning_details: None,
3043 },
3044 LanguageModelRequestMessage {
3045 role: Role::User,
3046 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3047 tool_use_id: echo_tool_use.id.clone(),
3048 tool_name: echo_tool_use.name,
3049 is_error: false,
3050 content: "test".into(),
3051 output: Some("test".into())
3052 })],
3053 cache: false,
3054 reasoning_details: None,
3055 },
3056 ],
3057 );
3058}
3059
3060#[gpui::test]
3061async fn test_agent_connection(cx: &mut TestAppContext) {
3062 cx.update(settings::init);
3063 let templates = Templates::new();
3064
3065 // Initialize language model system with test provider
3066 cx.update(|cx| {
3067 gpui_tokio::init(cx);
3068
3069 let http_client = FakeHttpClient::with_404_response();
3070 let clock = Arc::new(clock::FakeSystemClock::new());
3071 let client = Client::new(clock, http_client, cx);
3072 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3073 language_model::init(client.clone(), cx);
3074 language_models::init(user_store, client.clone(), cx);
3075 LanguageModelRegistry::test(cx);
3076 });
3077 cx.executor().forbid_parking();
3078
3079 // Create a project for new_thread
3080 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3081 fake_fs.insert_tree(path!("/test"), json!({})).await;
3082 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3083 let cwd = Path::new("/test");
3084 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3085
3086 // Create agent and connection
3087 let agent = NativeAgent::new(
3088 project.clone(),
3089 thread_store,
3090 templates.clone(),
3091 None,
3092 fake_fs.clone(),
3093 &mut cx.to_async(),
3094 )
3095 .await
3096 .unwrap();
3097 let connection = NativeAgentConnection(agent.clone());
3098
3099 // Create a thread using new_thread
3100 let connection_rc = Rc::new(connection.clone());
3101 let acp_thread = cx
3102 .update(|cx| connection_rc.new_session(project, cwd, cx))
3103 .await
3104 .expect("new_thread should succeed");
3105
3106 // Get the session_id from the AcpThread
3107 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3108
3109 // Test model_selector returns Some
3110 let selector_opt = connection.model_selector(&session_id);
3111 assert!(
3112 selector_opt.is_some(),
3113 "agent should always support ModelSelector"
3114 );
3115 let selector = selector_opt.unwrap();
3116
3117 // Test list_models
3118 let listed_models = cx
3119 .update(|cx| selector.list_models(cx))
3120 .await
3121 .expect("list_models should succeed");
3122 let AgentModelList::Grouped(listed_models) = listed_models else {
3123 panic!("Unexpected model list type");
3124 };
3125 assert!(!listed_models.is_empty(), "should have at least one model");
3126 assert_eq!(
3127 listed_models[&AgentModelGroupName("Fake".into())][0]
3128 .id
3129 .0
3130 .as_ref(),
3131 "fake/fake"
3132 );
3133
3134 // Test selected_model returns the default
3135 let model = cx
3136 .update(|cx| selector.selected_model(cx))
3137 .await
3138 .expect("selected_model should succeed");
3139 let model = cx
3140 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3141 .unwrap();
3142 let model = model.as_fake();
3143 assert_eq!(model.id().0, "fake", "should return default model");
3144
3145 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3146 cx.run_until_parked();
3147 model.send_last_completion_stream_text_chunk("def");
3148 cx.run_until_parked();
3149 acp_thread.read_with(cx, |thread, cx| {
3150 assert_eq!(
3151 thread.to_markdown(cx),
3152 indoc! {"
3153 ## User
3154
3155 abc
3156
3157 ## Assistant
3158
3159 def
3160
3161 "}
3162 )
3163 });
3164
3165 // Test cancel
3166 cx.update(|cx| connection.cancel(&session_id, cx));
3167 request.await.expect("prompt should fail gracefully");
3168
3169 // Explicitly close the session and drop the ACP thread.
3170 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3171 .await
3172 .unwrap();
3173 drop(acp_thread);
3174 let result = cx
3175 .update(|cx| {
3176 connection.prompt(
3177 Some(acp_thread::UserMessageId::new()),
3178 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3179 cx,
3180 )
3181 })
3182 .await;
3183 assert_eq!(
3184 result.as_ref().unwrap_err().to_string(),
3185 "Session not found",
3186 "unexpected result: {:?}",
3187 result
3188 );
3189}
3190
3191#[gpui::test]
3192async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3193 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3194 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
3195 let fake_model = model.as_fake();
3196
3197 let mut events = thread
3198 .update(cx, |thread, cx| {
3199 thread.send(UserMessageId::new(), ["Echo something"], cx)
3200 })
3201 .unwrap();
3202 cx.run_until_parked();
3203
3204 // Simulate streaming partial input.
3205 let input = json!({});
3206 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3207 LanguageModelToolUse {
3208 id: "1".into(),
3209 name: EchoTool::NAME.into(),
3210 raw_input: input.to_string(),
3211 input,
3212 is_input_complete: false,
3213 thought_signature: None,
3214 },
3215 ));
3216
3217 // Input streaming completed
3218 let input = json!({ "text": "Hello!" });
3219 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3220 LanguageModelToolUse {
3221 id: "1".into(),
3222 name: "echo".into(),
3223 raw_input: input.to_string(),
3224 input,
3225 is_input_complete: true,
3226 thought_signature: None,
3227 },
3228 ));
3229 fake_model.end_last_completion_stream();
3230 cx.run_until_parked();
3231
3232 let tool_call = expect_tool_call(&mut events).await;
3233 assert_eq!(
3234 tool_call,
3235 acp::ToolCall::new("1", "Echo")
3236 .raw_input(json!({}))
3237 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3238 );
3239 let update = expect_tool_call_update_fields(&mut events).await;
3240 assert_eq!(
3241 update,
3242 acp::ToolCallUpdate::new(
3243 "1",
3244 acp::ToolCallUpdateFields::new()
3245 .title("Echo")
3246 .kind(acp::ToolKind::Other)
3247 .raw_input(json!({ "text": "Hello!"}))
3248 )
3249 );
3250 let update = expect_tool_call_update_fields(&mut events).await;
3251 assert_eq!(
3252 update,
3253 acp::ToolCallUpdate::new(
3254 "1",
3255 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3256 )
3257 );
3258 let update = expect_tool_call_update_fields(&mut events).await;
3259 assert_eq!(
3260 update,
3261 acp::ToolCallUpdate::new(
3262 "1",
3263 acp::ToolCallUpdateFields::new()
3264 .status(acp::ToolCallStatus::Completed)
3265 .raw_output("Hello!")
3266 )
3267 );
3268}
3269
/// A completion that streams successfully on the first attempt must not
/// produce any `ThreadEvent::Retry` events. The reply is recorded as a single
/// assistant message.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Complete the request successfully on the first attempt.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until a `Stop` (or until the stream ends or yields an
    // error), recording any retries along the way.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3312
/// A retryable error partway through a completion produces exactly one retry
/// (attempt 1) after the requested delay. The retried turn's text is recorded
/// as a new assistant message following a `[resume]` marker.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream some text, then fail with an overload error asking the client to
    // wait 3 seconds before retrying.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the requested delay so the retry request is made.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events until a `Stop` (or until the stream ends or yields an
    // error), recording retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3376
/// A tool use received before a retryable error must survive the retry. The
/// retried request replays the assistant's tool use together with the tool's
/// result, and the turn then finishes normally.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model requests a complete echo tool call, then the stream fails with
    // a retryable overload error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, inspect the retried request. `messages[0]` is
    // skipped (presumably the system prompt; its content is not asserted here).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn and drain all remaining events.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3456
3457#[gpui::test]
3458async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3459 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3460 let fake_model = model.as_fake();
3461
3462 let mut events = thread
3463 .update(cx, |thread, cx| {
3464 thread.send(UserMessageId::new(), ["Hello!"], cx)
3465 })
3466 .unwrap();
3467 cx.run_until_parked();
3468
3469 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3470 fake_model.send_last_completion_stream_error(
3471 LanguageModelCompletionError::ServerOverloaded {
3472 provider: LanguageModelProviderName::new("Anthropic"),
3473 retry_after: Some(Duration::from_secs(3)),
3474 },
3475 );
3476 fake_model.end_last_completion_stream();
3477 cx.executor().advance_clock(Duration::from_secs(3));
3478 cx.run_until_parked();
3479 }
3480
3481 let mut errors = Vec::new();
3482 let mut retry_events = Vec::new();
3483 while let Some(event) = events.next().await {
3484 match event {
3485 Ok(ThreadEvent::Retry(retry_status)) => {
3486 retry_events.push(retry_status);
3487 }
3488 Ok(ThreadEvent::Stop(..)) => break,
3489 Err(error) => errors.push(error),
3490 _ => {}
3491 }
3492 }
3493
3494 assert_eq!(
3495 retry_events.len(),
3496 crate::thread::MAX_RETRY_ATTEMPTS as usize
3497 );
3498 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3499 assert_eq!(retry_events[i].attempt, i + 1);
3500 }
3501 assert_eq!(errors.len(), 1);
3502 let error = errors[0]
3503 .downcast_ref::<LanguageModelCompletionError>()
3504 .unwrap();
3505 assert!(matches!(
3506 error,
3507 LanguageModelCompletionError::ServerOverloaded { .. }
3508 ));
3509}
3510
3511/// Filters out the stop events for asserting against in tests
3512fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3513 result_events
3514 .into_iter()
3515 .filter_map(|event| match event.unwrap() {
3516 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3517 _ => None,
3518 })
3519 .collect()
3520}
3521
/// Handles produced by [`setup`] for a single thread-level test.
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for
    /// `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context handed to the thread on creation.
    project_context: Entity<ProjectContext>,
    /// The test project's context-server store, which can be passed to
    /// `setup_context_server`.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing both the test project and the watched settings file.
    fs: Arc<FakeFs>,
}
3529
/// Selects which language model [`setup`] gives the thread under test.
enum TestModel {
    /// A real model (`claude-sonnet-4-latest`) looked up in the language model
    /// registry. Setup uses a production client and authenticates the provider.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose output the test drives manually.
    Fake,
}
3534
3535impl TestModel {
3536 fn id(&self) -> LanguageModelId {
3537 match self {
3538 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3539 TestModel::Fake => unreachable!(),
3540 }
3541 }
3542}
3543
/// Creates a [`ThreadTest`] for `model`.
///
/// Setup proceeds as follows:
/// - Writes a user settings file to a fresh `FakeFs`; its default agent profile
///   (`test-profile`) enables the test tools.
/// - Initializes settings. For [`TestModel::Sonnet4`] it also initializes a
///   production client and the language model providers.
/// - Watches the settings file.
/// - Creates a test project rooted at `/test`.
/// - Resolves the model and constructs a [`Thread`] that uses it.
///
/// # Panics
///
/// Panics if any `unwrap`ed setup step fails. For `Sonnet4`, this includes
/// the model not being found in the registry or its provider failing to
/// authenticate.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably required because the `Sonnet4` path awaits real
    // network I/O that the test executor cannot drive by itself — confirm.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // User settings whose default profile enables every test tool.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider registration.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Load the settings file written above into the global settings store,
        // and keep applying later changes to it.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Find the real model in the registry and wait for its provider
                // to authenticate before handing it out.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3646
/// Load-time constructor (via `ctor`) that installs `env_logger` for the test
/// binary. It does this only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Stay silent unless the developer explicitly opted in to logging.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if !logging_requested {
        return;
    }
    env_logger::init();
}
3654
3655fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3656 let fs = fs.clone();
3657 cx.spawn({
3658 async move |cx| {
3659 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3660 cx.background_executor(),
3661 fs,
3662 paths::settings_file().clone(),
3663 );
3664 let _watcher_task = watcher_task;
3665
3666 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3667 cx.update(|cx| {
3668 SettingsStore::update_global(cx, |settings, cx| {
3669 settings.set_user_settings(&new_settings_content, cx)
3670 })
3671 })
3672 .ok();
3673 }
3674 }
3675 })
3676 .detach();
3677}
3678
3679fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3680 completion
3681 .tools
3682 .iter()
3683 .map(|tool| tool.name.clone())
3684 .collect()
3685}
3686
/// Configures a fake MCP context server named `name` and starts it on
/// `context_server_store`.
///
/// The server is registered in the global [`ProjectSettings`] as an enabled
/// stdio server with a dummy command path. It is then started on a fake
/// transport that answers `Initialize` (advertising tool support) and lists
/// `tools`.
///
/// Returns a receiver that yields every `CallTool` request together with a
/// oneshot sender. The fake server waits for the test to send the response
/// through that sender before replying. It panics if the sender is dropped
/// without a response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Register the server in project settings. The command is never executed,
    // because the server below is started on a fake transport.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Each incoming `CallTool` request is forwarded to the test through this channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test, then block this request until the
                // test sends back a response.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Drive pending tasks (such as server startup) before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3766
/// `tokens_before_message` reports, for each user message, the input-token
/// count of the request that preceded it. The first message has nothing before
/// it and reports `None`. Earlier messages keep their values as the thread grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3873
/// After truncating the thread at a message, `tokens_before_message` returns
/// `None` for the removed message. The remaining first message still reports
/// `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages, each with a response that reports token usage
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3944
3945#[gpui::test]
3946async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3947 init_test(cx);
3948
3949 let fs = FakeFs::new(cx.executor());
3950 fs.insert_tree("/root", json!({})).await;
3951 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3952
3953 // Test 1: Deny rule blocks command
3954 {
3955 let environment = Rc::new(cx.update(|cx| {
3956 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3957 }));
3958
3959 cx.update(|cx| {
3960 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3961 settings.tool_permissions.tools.insert(
3962 TerminalTool::NAME.into(),
3963 agent_settings::ToolRules {
3964 default: Some(settings::ToolPermissionMode::Confirm),
3965 always_allow: vec![],
3966 always_deny: vec![
3967 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3968 ],
3969 always_confirm: vec![],
3970 invalid_patterns: vec![],
3971 },
3972 );
3973 agent_settings::AgentSettings::override_global(settings, cx);
3974 });
3975
3976 #[allow(clippy::arc_with_non_send_sync)]
3977 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3978 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3979
3980 let task = cx.update(|cx| {
3981 tool.run(
3982 ToolInput::resolved(crate::TerminalToolInput {
3983 command: "rm -rf /".to_string(),
3984 cd: ".".to_string(),
3985 timeout_ms: None,
3986 }),
3987 event_stream,
3988 cx,
3989 )
3990 });
3991
3992 let result = task.await;
3993 assert!(
3994 result.is_err(),
3995 "expected command to be blocked by deny rule"
3996 );
3997 let err_msg = result.unwrap_err().to_lowercase();
3998 assert!(
3999 err_msg.contains("blocked"),
4000 "error should mention the command was blocked"
4001 );
4002 }
4003
4004 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4005 {
4006 let environment = Rc::new(cx.update(|cx| {
4007 FakeThreadEnvironment::default()
4008 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4009 }));
4010
4011 cx.update(|cx| {
4012 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4013 settings.tool_permissions.tools.insert(
4014 TerminalTool::NAME.into(),
4015 agent_settings::ToolRules {
4016 default: Some(settings::ToolPermissionMode::Deny),
4017 always_allow: vec![
4018 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4019 ],
4020 always_deny: vec![],
4021 always_confirm: vec![],
4022 invalid_patterns: vec![],
4023 },
4024 );
4025 agent_settings::AgentSettings::override_global(settings, cx);
4026 });
4027
4028 #[allow(clippy::arc_with_non_send_sync)]
4029 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4030 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4031
4032 let task = cx.update(|cx| {
4033 tool.run(
4034 ToolInput::resolved(crate::TerminalToolInput {
4035 command: "echo hello".to_string(),
4036 cd: ".".to_string(),
4037 timeout_ms: None,
4038 }),
4039 event_stream,
4040 cx,
4041 )
4042 });
4043
4044 let update = rx.expect_update_fields().await;
4045 assert!(
4046 update.content.iter().any(|blocks| {
4047 blocks
4048 .iter()
4049 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4050 }),
4051 "expected terminal content (allow rule should skip confirmation and override default deny)"
4052 );
4053
4054 let result = task.await;
4055 assert!(
4056 result.is_ok(),
4057 "expected command to succeed without confirmation"
4058 );
4059 }
4060
4061 // Test 3: global default: allow does NOT override always_confirm patterns
4062 {
4063 let environment = Rc::new(cx.update(|cx| {
4064 FakeThreadEnvironment::default()
4065 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4066 }));
4067
4068 cx.update(|cx| {
4069 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4070 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4071 settings.tool_permissions.tools.insert(
4072 TerminalTool::NAME.into(),
4073 agent_settings::ToolRules {
4074 default: Some(settings::ToolPermissionMode::Allow),
4075 always_allow: vec![],
4076 always_deny: vec![],
4077 always_confirm: vec![
4078 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4079 ],
4080 invalid_patterns: vec![],
4081 },
4082 );
4083 agent_settings::AgentSettings::override_global(settings, cx);
4084 });
4085
4086 #[allow(clippy::arc_with_non_send_sync)]
4087 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4088 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4089
4090 let _task = cx.update(|cx| {
4091 tool.run(
4092 ToolInput::resolved(crate::TerminalToolInput {
4093 command: "sudo rm file".to_string(),
4094 cd: ".".to_string(),
4095 timeout_ms: None,
4096 }),
4097 event_stream,
4098 cx,
4099 )
4100 });
4101
4102 // With global default: allow, confirm patterns are still respected
4103 // The expect_authorization() call will panic if no authorization is requested,
4104 // which validates that the confirm pattern still triggers confirmation
4105 let _auth = rx.expect_authorization().await;
4106
4107 drop(_task);
4108 }
4109
4110 // Test 4: tool-specific default: deny is respected even with global default: allow
4111 {
4112 let environment = Rc::new(cx.update(|cx| {
4113 FakeThreadEnvironment::default()
4114 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4115 }));
4116
4117 cx.update(|cx| {
4118 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4119 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4120 settings.tool_permissions.tools.insert(
4121 TerminalTool::NAME.into(),
4122 agent_settings::ToolRules {
4123 default: Some(settings::ToolPermissionMode::Deny),
4124 always_allow: vec![],
4125 always_deny: vec![],
4126 always_confirm: vec![],
4127 invalid_patterns: vec![],
4128 },
4129 );
4130 agent_settings::AgentSettings::override_global(settings, cx);
4131 });
4132
4133 #[allow(clippy::arc_with_non_send_sync)]
4134 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4135 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4136
4137 let task = cx.update(|cx| {
4138 tool.run(
4139 ToolInput::resolved(crate::TerminalToolInput {
4140 command: "echo hello".to_string(),
4141 cd: ".".to_string(),
4142 timeout_ms: None,
4143 }),
4144 event_stream,
4145 cx,
4146 )
4147 });
4148
4149 // tool-specific default: deny is respected even with global default: allow
4150 let result = task.await;
4151 assert!(
4152 result.is_err(),
4153 "expected command to be blocked by tool-specific deny default"
4154 );
4155 let err_msg = result.unwrap_err().to_lowercase();
4156 assert!(
4157 err_msg.contains("disabled"),
4158 "error should mention the tool is disabled, got: {err_msg}"
4159 );
4160 }
4161}
4162
4163#[gpui::test]
4164async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4165 init_test(cx);
4166 cx.update(|cx| {
4167 LanguageModelRegistry::test(cx);
4168 });
4169 cx.update(|cx| {
4170 cx.update_flags(true, vec!["subagents".to_string()]);
4171 });
4172
4173 let fs = FakeFs::new(cx.executor());
4174 fs.insert_tree(
4175 "/",
4176 json!({
4177 "a": {
4178 "b.md": "Lorem"
4179 }
4180 }),
4181 )
4182 .await;
4183 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4184 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4185 let agent = NativeAgent::new(
4186 project.clone(),
4187 thread_store.clone(),
4188 Templates::new(),
4189 None,
4190 fs.clone(),
4191 &mut cx.to_async(),
4192 )
4193 .await
4194 .unwrap();
4195 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4196
4197 let acp_thread = cx
4198 .update(|cx| {
4199 connection
4200 .clone()
4201 .new_session(project.clone(), Path::new(""), cx)
4202 })
4203 .await
4204 .unwrap();
4205 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4206 let thread = agent.read_with(cx, |agent, _| {
4207 agent.sessions.get(&session_id).unwrap().thread.clone()
4208 });
4209 let model = Arc::new(FakeLanguageModel::default());
4210
4211 // Ensure empty threads are not saved, even if they get mutated.
4212 thread.update(cx, |thread, cx| {
4213 thread.set_model(model.clone(), cx);
4214 });
4215 cx.run_until_parked();
4216
4217 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4218 cx.run_until_parked();
4219 model.send_last_completion_stream_text_chunk("spawning subagent");
4220 let subagent_tool_input = SpawnAgentToolInput {
4221 label: "label".to_string(),
4222 message: "subagent task prompt".to_string(),
4223 session_id: None,
4224 };
4225 let subagent_tool_use = LanguageModelToolUse {
4226 id: "subagent_1".into(),
4227 name: SpawnAgentTool::NAME.into(),
4228 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4229 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4230 is_input_complete: true,
4231 thought_signature: None,
4232 };
4233 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4234 subagent_tool_use,
4235 ));
4236 model.end_last_completion_stream();
4237
4238 cx.run_until_parked();
4239
4240 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4241 thread
4242 .running_subagent_ids(cx)
4243 .get(0)
4244 .expect("subagent thread should be running")
4245 .clone()
4246 });
4247
4248 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4249 agent
4250 .sessions
4251 .get(&subagent_session_id)
4252 .expect("subagent session should exist")
4253 .acp_thread
4254 .clone()
4255 });
4256
4257 model.send_last_completion_stream_text_chunk("subagent task response");
4258 model.end_last_completion_stream();
4259
4260 cx.run_until_parked();
4261
4262 assert_eq!(
4263 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4264 indoc! {"
4265 ## User
4266
4267 subagent task prompt
4268
4269 ## Assistant
4270
4271 subagent task response
4272
4273 "}
4274 );
4275
4276 model.send_last_completion_stream_text_chunk("Response");
4277 model.end_last_completion_stream();
4278
4279 send.await.unwrap();
4280
4281 assert_eq!(
4282 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4283 indoc! {r#"
4284 ## User
4285
4286 Prompt
4287
4288 ## Assistant
4289
4290 spawning subagent
4291
4292 **Tool Call: label**
4293 Status: Completed
4294
4295 subagent task response
4296
4297
4298 ## Assistant
4299
4300 Response
4301
4302 "#},
4303 );
4304}
4305
4306#[gpui::test]
4307async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4308 init_test(cx);
4309 cx.update(|cx| {
4310 LanguageModelRegistry::test(cx);
4311 });
4312 cx.update(|cx| {
4313 cx.update_flags(true, vec!["subagents".to_string()]);
4314 });
4315
4316 let fs = FakeFs::new(cx.executor());
4317 fs.insert_tree(
4318 "/",
4319 json!({
4320 "a": {
4321 "b.md": "Lorem"
4322 }
4323 }),
4324 )
4325 .await;
4326 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4327 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4328 let agent = NativeAgent::new(
4329 project.clone(),
4330 thread_store.clone(),
4331 Templates::new(),
4332 None,
4333 fs.clone(),
4334 &mut cx.to_async(),
4335 )
4336 .await
4337 .unwrap();
4338 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4339
4340 let acp_thread = cx
4341 .update(|cx| {
4342 connection
4343 .clone()
4344 .new_session(project.clone(), Path::new(""), cx)
4345 })
4346 .await
4347 .unwrap();
4348 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4349 let thread = agent.read_with(cx, |agent, _| {
4350 agent.sessions.get(&session_id).unwrap().thread.clone()
4351 });
4352 let model = Arc::new(FakeLanguageModel::default());
4353
4354 // Ensure empty threads are not saved, even if they get mutated.
4355 thread.update(cx, |thread, cx| {
4356 thread.set_model(model.clone(), cx);
4357 });
4358 cx.run_until_parked();
4359
4360 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4361 cx.run_until_parked();
4362 model.send_last_completion_stream_text_chunk("spawning subagent");
4363 let subagent_tool_input = SpawnAgentToolInput {
4364 label: "label".to_string(),
4365 message: "subagent task prompt".to_string(),
4366 session_id: None,
4367 };
4368 let subagent_tool_use = LanguageModelToolUse {
4369 id: "subagent_1".into(),
4370 name: SpawnAgentTool::NAME.into(),
4371 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4372 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4373 is_input_complete: true,
4374 thought_signature: None,
4375 };
4376 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4377 subagent_tool_use,
4378 ));
4379 model.end_last_completion_stream();
4380
4381 cx.run_until_parked();
4382
4383 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4384 thread
4385 .running_subagent_ids(cx)
4386 .get(0)
4387 .expect("subagent thread should be running")
4388 .clone()
4389 });
4390 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4391 agent
4392 .sessions
4393 .get(&subagent_session_id)
4394 .expect("subagent session should exist")
4395 .acp_thread
4396 .clone()
4397 });
4398
4399 // model.send_last_completion_stream_text_chunk("subagent task response");
4400 // model.end_last_completion_stream();
4401
4402 // cx.run_until_parked();
4403
4404 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4405
4406 cx.run_until_parked();
4407
4408 send.await.unwrap();
4409
4410 acp_thread.read_with(cx, |thread, cx| {
4411 assert_eq!(thread.status(), ThreadStatus::Idle);
4412 assert_eq!(
4413 thread.to_markdown(cx),
4414 indoc! {"
4415 ## User
4416
4417 Prompt
4418
4419 ## Assistant
4420
4421 spawning subagent
4422
4423 **Tool Call: label**
4424 Status: Canceled
4425
4426 "}
4427 );
4428 });
4429 subagent_acp_thread.read_with(cx, |thread, cx| {
4430 assert_eq!(thread.status(), ThreadStatus::Idle);
4431 assert_eq!(
4432 thread.to_markdown(cx),
4433 indoc! {"
4434 ## User
4435
4436 subagent task prompt
4437
4438 "}
4439 );
4440 });
4441}
4442
4443#[gpui::test]
4444async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4445 init_test(cx);
4446 cx.update(|cx| {
4447 LanguageModelRegistry::test(cx);
4448 });
4449 cx.update(|cx| {
4450 cx.update_flags(true, vec!["subagents".to_string()]);
4451 });
4452
4453 let fs = FakeFs::new(cx.executor());
4454 fs.insert_tree(
4455 "/",
4456 json!({
4457 "a": {
4458 "b.md": "Lorem"
4459 }
4460 }),
4461 )
4462 .await;
4463 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4464 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4465 let agent = NativeAgent::new(
4466 project.clone(),
4467 thread_store.clone(),
4468 Templates::new(),
4469 None,
4470 fs.clone(),
4471 &mut cx.to_async(),
4472 )
4473 .await
4474 .unwrap();
4475 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4476
4477 let acp_thread = cx
4478 .update(|cx| {
4479 connection
4480 .clone()
4481 .new_session(project.clone(), Path::new(""), cx)
4482 })
4483 .await
4484 .unwrap();
4485 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4486 let thread = agent.read_with(cx, |agent, _| {
4487 agent.sessions.get(&session_id).unwrap().thread.clone()
4488 });
4489 let model = Arc::new(FakeLanguageModel::default());
4490
4491 thread.update(cx, |thread, cx| {
4492 thread.set_model(model.clone(), cx);
4493 });
4494 cx.run_until_parked();
4495
4496 // === First turn: create subagent ===
4497 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4498 cx.run_until_parked();
4499 model.send_last_completion_stream_text_chunk("spawning subagent");
4500 let subagent_tool_input = SpawnAgentToolInput {
4501 label: "initial task".to_string(),
4502 message: "do the first task".to_string(),
4503 session_id: None,
4504 };
4505 let subagent_tool_use = LanguageModelToolUse {
4506 id: "subagent_1".into(),
4507 name: SpawnAgentTool::NAME.into(),
4508 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4509 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4510 is_input_complete: true,
4511 thought_signature: None,
4512 };
4513 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4514 subagent_tool_use,
4515 ));
4516 model.end_last_completion_stream();
4517
4518 cx.run_until_parked();
4519
4520 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4521 thread
4522 .running_subagent_ids(cx)
4523 .get(0)
4524 .expect("subagent thread should be running")
4525 .clone()
4526 });
4527
4528 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4529 agent
4530 .sessions
4531 .get(&subagent_session_id)
4532 .expect("subagent session should exist")
4533 .acp_thread
4534 .clone()
4535 });
4536
4537 // Subagent responds
4538 model.send_last_completion_stream_text_chunk("first task response");
4539 model.end_last_completion_stream();
4540
4541 cx.run_until_parked();
4542
4543 // Parent model responds to complete first turn
4544 model.send_last_completion_stream_text_chunk("First response");
4545 model.end_last_completion_stream();
4546
4547 send.await.unwrap();
4548
4549 // Verify subagent is no longer running
4550 thread.read_with(cx, |thread, cx| {
4551 assert!(
4552 thread.running_subagent_ids(cx).is_empty(),
4553 "subagent should not be running after completion"
4554 );
4555 });
4556
4557 // === Second turn: resume subagent with session_id ===
4558 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4559 cx.run_until_parked();
4560 model.send_last_completion_stream_text_chunk("resuming subagent");
4561 let resume_tool_input = SpawnAgentToolInput {
4562 label: "follow-up task".to_string(),
4563 message: "do the follow-up task".to_string(),
4564 session_id: Some(subagent_session_id.clone()),
4565 };
4566 let resume_tool_use = LanguageModelToolUse {
4567 id: "subagent_2".into(),
4568 name: SpawnAgentTool::NAME.into(),
4569 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4570 input: serde_json::to_value(&resume_tool_input).unwrap(),
4571 is_input_complete: true,
4572 thought_signature: None,
4573 };
4574 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4575 model.end_last_completion_stream();
4576
4577 cx.run_until_parked();
4578
4579 // Subagent should be running again with the same session
4580 thread.read_with(cx, |thread, cx| {
4581 let running = thread.running_subagent_ids(cx);
4582 assert_eq!(running.len(), 1, "subagent should be running");
4583 assert_eq!(running[0], subagent_session_id, "should be same session");
4584 });
4585
4586 // Subagent responds to follow-up
4587 model.send_last_completion_stream_text_chunk("follow-up task response");
4588 model.end_last_completion_stream();
4589
4590 cx.run_until_parked();
4591
4592 // Parent model responds to complete second turn
4593 model.send_last_completion_stream_text_chunk("Second response");
4594 model.end_last_completion_stream();
4595
4596 send2.await.unwrap();
4597
4598 // Verify subagent is no longer running
4599 thread.read_with(cx, |thread, cx| {
4600 assert!(
4601 thread.running_subagent_ids(cx).is_empty(),
4602 "subagent should not be running after resume completion"
4603 );
4604 });
4605
4606 // Verify the subagent's acp thread has both conversation turns
4607 assert_eq!(
4608 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4609 indoc! {"
4610 ## User
4611
4612 do the first task
4613
4614 ## Assistant
4615
4616 first task response
4617
4618 ## User
4619
4620 do the follow-up task
4621
4622 ## Assistant
4623
4624 follow-up task response
4625
4626 "}
4627 );
4628}
4629
4630#[gpui::test]
4631async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4632 init_test(cx);
4633
4634 cx.update(|cx| {
4635 cx.update_flags(true, vec!["subagents".to_string()]);
4636 });
4637
4638 let fs = FakeFs::new(cx.executor());
4639 fs.insert_tree(path!("/test"), json!({})).await;
4640 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4641 let project_context = cx.new(|_cx| ProjectContext::default());
4642 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4643 let context_server_registry =
4644 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4645 let model = Arc::new(FakeLanguageModel::default());
4646
4647 let environment = Rc::new(cx.update(|cx| {
4648 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4649 }));
4650
4651 let thread = cx.new(|cx| {
4652 let mut thread = Thread::new(
4653 project.clone(),
4654 project_context,
4655 context_server_registry,
4656 Templates::new(),
4657 Some(model),
4658 cx,
4659 );
4660 thread.add_default_tools(environment, cx);
4661 thread
4662 });
4663
4664 thread.read_with(cx, |thread, _| {
4665 assert!(
4666 thread.has_registered_tool(SpawnAgentTool::NAME),
4667 "subagent tool should be present when feature flag is enabled"
4668 );
4669 });
4670}
4671
4672#[gpui::test]
4673async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4674 init_test(cx);
4675
4676 cx.update(|cx| {
4677 cx.update_flags(true, vec!["subagents".to_string()]);
4678 });
4679
4680 let fs = FakeFs::new(cx.executor());
4681 fs.insert_tree(path!("/test"), json!({})).await;
4682 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4683 let project_context = cx.new(|_cx| ProjectContext::default());
4684 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4685 let context_server_registry =
4686 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4687 let model = Arc::new(FakeLanguageModel::default());
4688
4689 let parent_thread = cx.new(|cx| {
4690 Thread::new(
4691 project.clone(),
4692 project_context,
4693 context_server_registry,
4694 Templates::new(),
4695 Some(model.clone()),
4696 cx,
4697 )
4698 });
4699
4700 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4701 subagent_thread.read_with(cx, |subagent_thread, cx| {
4702 assert!(subagent_thread.is_subagent());
4703 assert_eq!(subagent_thread.depth(), 1);
4704 assert_eq!(
4705 subagent_thread.model().map(|model| model.id()),
4706 Some(model.id())
4707 );
4708 assert_eq!(
4709 subagent_thread.parent_thread_id(),
4710 Some(parent_thread.read(cx).id().clone())
4711 );
4712 });
4713}
4714
4715#[gpui::test]
4716async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4717 init_test(cx);
4718
4719 cx.update(|cx| {
4720 cx.update_flags(true, vec!["subagents".to_string()]);
4721 });
4722
4723 let fs = FakeFs::new(cx.executor());
4724 fs.insert_tree(path!("/test"), json!({})).await;
4725 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4726 let project_context = cx.new(|_cx| ProjectContext::default());
4727 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4728 let context_server_registry =
4729 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4730 let model = Arc::new(FakeLanguageModel::default());
4731 let environment = Rc::new(cx.update(|cx| {
4732 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4733 }));
4734
4735 let deep_parent_thread = cx.new(|cx| {
4736 let mut thread = Thread::new(
4737 project.clone(),
4738 project_context,
4739 context_server_registry,
4740 Templates::new(),
4741 Some(model.clone()),
4742 cx,
4743 );
4744 thread.set_subagent_context(SubagentContext {
4745 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4746 depth: MAX_SUBAGENT_DEPTH - 1,
4747 });
4748 thread
4749 });
4750 let deep_subagent_thread = cx.new(|cx| {
4751 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4752 thread.add_default_tools(environment, cx);
4753 thread
4754 });
4755
4756 deep_subagent_thread.read_with(cx, |thread, _| {
4757 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4758 assert!(
4759 !thread.has_registered_tool(SpawnAgentTool::NAME),
4760 "subagent tool should not be present at max depth"
4761 );
4762 });
4763}
4764
4765#[gpui::test]
4766async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4767 init_test(cx);
4768
4769 cx.update(|cx| {
4770 cx.update_flags(true, vec!["subagents".to_string()]);
4771 });
4772
4773 let fs = FakeFs::new(cx.executor());
4774 fs.insert_tree(path!("/test"), json!({})).await;
4775 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4776 let project_context = cx.new(|_cx| ProjectContext::default());
4777 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4778 let context_server_registry =
4779 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4780 let model = Arc::new(FakeLanguageModel::default());
4781
4782 let parent = cx.new(|cx| {
4783 Thread::new(
4784 project.clone(),
4785 project_context.clone(),
4786 context_server_registry.clone(),
4787 Templates::new(),
4788 Some(model.clone()),
4789 cx,
4790 )
4791 });
4792
4793 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4794
4795 parent.update(cx, |thread, _cx| {
4796 thread.register_running_subagent(subagent.downgrade());
4797 });
4798
4799 subagent
4800 .update(cx, |thread, cx| {
4801 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4802 })
4803 .unwrap();
4804 cx.run_until_parked();
4805
4806 subagent.read_with(cx, |thread, _| {
4807 assert!(!thread.is_turn_complete(), "subagent should be running");
4808 });
4809
4810 parent.update(cx, |thread, cx| {
4811 thread.cancel(cx).detach();
4812 });
4813
4814 subagent.read_with(cx, |thread, _| {
4815 assert!(
4816 thread.is_turn_complete(),
4817 "subagent should be cancelled when parent cancels"
4818 );
4819 });
4820}
4821
4822#[gpui::test]
4823async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
4824 init_test(cx);
4825 cx.update(|cx| {
4826 LanguageModelRegistry::test(cx);
4827 });
4828 cx.update(|cx| {
4829 cx.update_flags(true, vec!["subagents".to_string()]);
4830 });
4831
4832 let fs = FakeFs::new(cx.executor());
4833 fs.insert_tree(
4834 "/",
4835 json!({
4836 "a": {
4837 "b.md": "Lorem"
4838 }
4839 }),
4840 )
4841 .await;
4842 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4843 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4844 let agent = NativeAgent::new(
4845 project.clone(),
4846 thread_store.clone(),
4847 Templates::new(),
4848 None,
4849 fs.clone(),
4850 &mut cx.to_async(),
4851 )
4852 .await
4853 .unwrap();
4854 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4855
4856 let acp_thread = cx
4857 .update(|cx| {
4858 connection
4859 .clone()
4860 .new_session(project.clone(), Path::new(""), cx)
4861 })
4862 .await
4863 .unwrap();
4864 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4865 let thread = agent.read_with(cx, |agent, _| {
4866 agent.sessions.get(&session_id).unwrap().thread.clone()
4867 });
4868 let model = Arc::new(FakeLanguageModel::default());
4869
4870 thread.update(cx, |thread, cx| {
4871 thread.set_model(model.clone(), cx);
4872 });
4873 cx.run_until_parked();
4874
4875 // Start the parent turn
4876 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4877 cx.run_until_parked();
4878 model.send_last_completion_stream_text_chunk("spawning subagent");
4879 let subagent_tool_input = SpawnAgentToolInput {
4880 label: "label".to_string(),
4881 message: "subagent task prompt".to_string(),
4882 session_id: None,
4883 };
4884 let subagent_tool_use = LanguageModelToolUse {
4885 id: "subagent_1".into(),
4886 name: SpawnAgentTool::NAME.into(),
4887 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4888 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4889 is_input_complete: true,
4890 thought_signature: None,
4891 };
4892 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4893 subagent_tool_use,
4894 ));
4895 model.end_last_completion_stream();
4896
4897 cx.run_until_parked();
4898
4899 // Verify subagent is running
4900 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4901 thread
4902 .running_subagent_ids(cx)
4903 .get(0)
4904 .expect("subagent thread should be running")
4905 .clone()
4906 });
4907
4908 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
4909 model.send_last_completion_stream_text_chunk("partial work");
4910 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
4911 TokenUsage {
4912 input_tokens: 850_000,
4913 output_tokens: 0,
4914 cache_creation_input_tokens: 0,
4915 cache_read_input_tokens: 0,
4916 },
4917 ));
4918
4919 cx.run_until_parked();
4920
4921 // The subagent should no longer be running
4922 thread.read_with(cx, |thread, cx| {
4923 assert!(
4924 thread.running_subagent_ids(cx).is_empty(),
4925 "subagent should be stopped after context window warning"
4926 );
4927 });
4928
4929 // The parent model should get a new completion request to respond to the tool error
4930 model.send_last_completion_stream_text_chunk("Response after warning");
4931 model.end_last_completion_stream();
4932
4933 send.await.unwrap();
4934
4935 // Verify the parent thread shows the warning error in the tool call
4936 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
4937 assert!(
4938 markdown.contains("nearing the end of its context window"),
4939 "tool output should contain context window warning message, got:\n{markdown}"
4940 );
4941 assert!(
4942 markdown.contains("Status: Failed"),
4943 "tool call should have Failed status, got:\n{markdown}"
4944 );
4945
4946 // Verify the subagent session still exists (can be resumed)
4947 agent.read_with(cx, |agent, _cx| {
4948 assert!(
4949 agent.sessions.contains_key(&subagent_session_id),
4950 "subagent session should still exist for potential resume"
4951 );
4952 });
4953}
4954
4955#[gpui::test]
4956async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
4957 init_test(cx);
4958 cx.update(|cx| {
4959 LanguageModelRegistry::test(cx);
4960 });
4961 cx.update(|cx| {
4962 cx.update_flags(true, vec!["subagents".to_string()]);
4963 });
4964
4965 let fs = FakeFs::new(cx.executor());
4966 fs.insert_tree(
4967 "/",
4968 json!({
4969 "a": {
4970 "b.md": "Lorem"
4971 }
4972 }),
4973 )
4974 .await;
4975 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4976 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4977 let agent = NativeAgent::new(
4978 project.clone(),
4979 thread_store.clone(),
4980 Templates::new(),
4981 None,
4982 fs.clone(),
4983 &mut cx.to_async(),
4984 )
4985 .await
4986 .unwrap();
4987 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4988
4989 let acp_thread = cx
4990 .update(|cx| {
4991 connection
4992 .clone()
4993 .new_session(project.clone(), Path::new(""), cx)
4994 })
4995 .await
4996 .unwrap();
4997 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4998 let thread = agent.read_with(cx, |agent, _| {
4999 agent.sessions.get(&session_id).unwrap().thread.clone()
5000 });
5001 let model = Arc::new(FakeLanguageModel::default());
5002
5003 thread.update(cx, |thread, cx| {
5004 thread.set_model(model.clone(), cx);
5005 });
5006 cx.run_until_parked();
5007
5008 // === First turn: create subagent, trigger context window warning ===
5009 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5010 cx.run_until_parked();
5011 model.send_last_completion_stream_text_chunk("spawning subagent");
5012 let subagent_tool_input = SpawnAgentToolInput {
5013 label: "initial task".to_string(),
5014 message: "do the first task".to_string(),
5015 session_id: None,
5016 };
5017 let subagent_tool_use = LanguageModelToolUse {
5018 id: "subagent_1".into(),
5019 name: SpawnAgentTool::NAME.into(),
5020 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5021 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5022 is_input_complete: true,
5023 thought_signature: None,
5024 };
5025 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5026 subagent_tool_use,
5027 ));
5028 model.end_last_completion_stream();
5029
5030 cx.run_until_parked();
5031
5032 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5033 thread
5034 .running_subagent_ids(cx)
5035 .get(0)
5036 .expect("subagent thread should be running")
5037 .clone()
5038 });
5039
5040 // Subagent sends a usage update that crosses the warning threshold.
5041 // This triggers Normal→Warning, stopping the subagent.
5042 model.send_last_completion_stream_text_chunk("partial work");
5043 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5044 TokenUsage {
5045 input_tokens: 850_000,
5046 output_tokens: 0,
5047 cache_creation_input_tokens: 0,
5048 cache_read_input_tokens: 0,
5049 },
5050 ));
5051
5052 cx.run_until_parked();
5053
5054 // Verify the first turn was stopped with a context window warning
5055 thread.read_with(cx, |thread, cx| {
5056 assert!(
5057 thread.running_subagent_ids(cx).is_empty(),
5058 "subagent should be stopped after context window warning"
5059 );
5060 });
5061
5062 // Parent model responds to complete first turn
5063 model.send_last_completion_stream_text_chunk("First response");
5064 model.end_last_completion_stream();
5065
5066 send.await.unwrap();
5067
5068 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5069 assert!(
5070 markdown.contains("nearing the end of its context window"),
5071 "first turn should have context window warning, got:\n{markdown}"
5072 );
5073
5074 // === Second turn: resume the same subagent (now at Warning level) ===
5075 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5076 cx.run_until_parked();
5077 model.send_last_completion_stream_text_chunk("resuming subagent");
5078 let resume_tool_input = SpawnAgentToolInput {
5079 label: "follow-up task".to_string(),
5080 message: "do the follow-up task".to_string(),
5081 session_id: Some(subagent_session_id.clone()),
5082 };
5083 let resume_tool_use = LanguageModelToolUse {
5084 id: "subagent_2".into(),
5085 name: SpawnAgentTool::NAME.into(),
5086 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5087 input: serde_json::to_value(&resume_tool_input).unwrap(),
5088 is_input_complete: true,
5089 thought_signature: None,
5090 };
5091 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5092 model.end_last_completion_stream();
5093
5094 cx.run_until_parked();
5095
5096 // Subagent responds with tokens still at warning level (no worse).
5097 // Since ratio_before_prompt was already Warning, this should NOT
5098 // trigger the context window warning again.
5099 model.send_last_completion_stream_text_chunk("follow-up task response");
5100 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5101 TokenUsage {
5102 input_tokens: 870_000,
5103 output_tokens: 0,
5104 cache_creation_input_tokens: 0,
5105 cache_read_input_tokens: 0,
5106 },
5107 ));
5108 model.end_last_completion_stream();
5109
5110 cx.run_until_parked();
5111
5112 // Parent model responds to complete second turn
5113 model.send_last_completion_stream_text_chunk("Second response");
5114 model.end_last_completion_stream();
5115
5116 send2.await.unwrap();
5117
5118 // The resumed subagent should have completed normally since the ratio
5119 // didn't transition (it was Warning before and stayed at Warning)
5120 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5121 assert!(
5122 markdown.contains("follow-up task response"),
5123 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5124 );
5125 // The second tool call should NOT have a context window warning
5126 let second_tool_pos = markdown
5127 .find("follow-up task")
5128 .expect("should find follow-up tool call");
5129 let after_second_tool = &markdown[second_tool_pos..];
5130 assert!(
5131 !after_second_tool.contains("nearing the end of its context window"),
5132 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5133 );
5134}
5135
// A non-retryable model error inside a running subagent stops the subagent.
// The error surfaces as a failed spawn-agent tool call in the parent
// transcript, and the parent turn still completes.
#[gpui::test]
async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
    init_test(cx);
    cx.update(|cx| {
        LanguageModelRegistry::test(cx);
    });
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new(""), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    let model = Arc::new(FakeLanguageModel::default());

    // Both the parent and the subagent completions are driven through `model`.
    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });
    cx.run_until_parked();

    // Start the parent turn
    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("spawning subagent");
    let subagent_tool_input = SpawnAgentToolInput {
        label: "label".to_string(),
        message: "subagent task prompt".to_string(),
        session_id: None,
    };
    let subagent_tool_use = LanguageModelToolUse {
        id: "subagent_1".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
        input: serde_json::to_value(&subagent_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        subagent_tool_use,
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Verify subagent is running
    thread.read_with(cx, |thread, cx| {
        assert!(
            !thread.running_subagent_ids(cx).is_empty(),
            "subagent should be running"
        );
    });

    // The subagent's model returns a non-retryable error. The most recent
    // pending completion at this point is the subagent's.
    model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
        tokens: None,
    });

    cx.run_until_parked();

    // The subagent should no longer be running
    thread.read_with(cx, |thread, cx| {
        assert!(
            thread.running_subagent_ids(cx).is_empty(),
            "subagent should not be running after error"
        );
    });

    // The parent model should get a new completion request to respond to the tool error
    model.send_last_completion_stream_text_chunk("Response after error");
    model.end_last_completion_stream();

    send.await.unwrap();

    // Verify the parent thread shows the error in the tool call. Only the
    // status is checked here, not the error text.
    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("Status: Failed"),
        "tool call should have Failed status after model error, got:\n{markdown}"
    );
}
5249
5250#[gpui::test]
5251async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5252 init_test(cx);
5253
5254 let fs = FakeFs::new(cx.executor());
5255 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5256 .await;
5257 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5258
5259 cx.update(|cx| {
5260 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5261 settings.tool_permissions.tools.insert(
5262 EditFileTool::NAME.into(),
5263 agent_settings::ToolRules {
5264 default: Some(settings::ToolPermissionMode::Allow),
5265 always_allow: vec![],
5266 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5267 always_confirm: vec![],
5268 invalid_patterns: vec![],
5269 },
5270 );
5271 agent_settings::AgentSettings::override_global(settings, cx);
5272 });
5273
5274 let context_server_registry =
5275 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5276 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5277 let templates = crate::Templates::new();
5278 let thread = cx.new(|cx| {
5279 crate::Thread::new(
5280 project.clone(),
5281 cx.new(|_cx| prompt_store::ProjectContext::default()),
5282 context_server_registry,
5283 templates.clone(),
5284 None,
5285 cx,
5286 )
5287 });
5288
5289 #[allow(clippy::arc_with_non_send_sync)]
5290 let tool = Arc::new(crate::EditFileTool::new(
5291 project.clone(),
5292 thread.downgrade(),
5293 language_registry,
5294 templates,
5295 ));
5296 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5297
5298 let task = cx.update(|cx| {
5299 tool.run(
5300 ToolInput::resolved(crate::EditFileToolInput {
5301 display_description: "Edit sensitive file".to_string(),
5302 path: "root/sensitive_config.txt".into(),
5303 mode: crate::EditFileMode::Edit,
5304 }),
5305 event_stream,
5306 cx,
5307 )
5308 });
5309
5310 let result = task.await;
5311 assert!(result.is_err(), "expected edit to be blocked");
5312 assert!(
5313 result.unwrap_err().to_string().contains("blocked"),
5314 "error should mention the edit was blocked"
5315 );
5316}
5317
5318#[gpui::test]
5319async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5320 init_test(cx);
5321
5322 let fs = FakeFs::new(cx.executor());
5323 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5324 .await;
5325 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5326
5327 cx.update(|cx| {
5328 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5329 settings.tool_permissions.tools.insert(
5330 DeletePathTool::NAME.into(),
5331 agent_settings::ToolRules {
5332 default: Some(settings::ToolPermissionMode::Allow),
5333 always_allow: vec![],
5334 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5335 always_confirm: vec![],
5336 invalid_patterns: vec![],
5337 },
5338 );
5339 agent_settings::AgentSettings::override_global(settings, cx);
5340 });
5341
5342 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5343
5344 #[allow(clippy::arc_with_non_send_sync)]
5345 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5346 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5347
5348 let task = cx.update(|cx| {
5349 tool.run(
5350 ToolInput::resolved(crate::DeletePathToolInput {
5351 path: "root/important_data.txt".to_string(),
5352 }),
5353 event_stream,
5354 cx,
5355 )
5356 });
5357
5358 let result = task.await;
5359 assert!(result.is_err(), "expected deletion to be blocked");
5360 assert!(
5361 result.unwrap_err().contains("blocked"),
5362 "error should mention the deletion was blocked"
5363 );
5364}
5365
5366#[gpui::test]
5367async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5368 init_test(cx);
5369
5370 let fs = FakeFs::new(cx.executor());
5371 fs.insert_tree(
5372 "/root",
5373 json!({
5374 "safe.txt": "content",
5375 "protected": {}
5376 }),
5377 )
5378 .await;
5379 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5380
5381 cx.update(|cx| {
5382 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5383 settings.tool_permissions.tools.insert(
5384 MovePathTool::NAME.into(),
5385 agent_settings::ToolRules {
5386 default: Some(settings::ToolPermissionMode::Allow),
5387 always_allow: vec![],
5388 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5389 always_confirm: vec![],
5390 invalid_patterns: vec![],
5391 },
5392 );
5393 agent_settings::AgentSettings::override_global(settings, cx);
5394 });
5395
5396 #[allow(clippy::arc_with_non_send_sync)]
5397 let tool = Arc::new(crate::MovePathTool::new(project));
5398 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5399
5400 let task = cx.update(|cx| {
5401 tool.run(
5402 ToolInput::resolved(crate::MovePathToolInput {
5403 source_path: "root/safe.txt".to_string(),
5404 destination_path: "root/protected/safe.txt".to_string(),
5405 }),
5406 event_stream,
5407 cx,
5408 )
5409 });
5410
5411 let result = task.await;
5412 assert!(
5413 result.is_err(),
5414 "expected move to be blocked due to destination path"
5415 );
5416 assert!(
5417 result.unwrap_err().contains("blocked"),
5418 "error should mention the move was blocked"
5419 );
5420}
5421
5422#[gpui::test]
5423async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5424 init_test(cx);
5425
5426 let fs = FakeFs::new(cx.executor());
5427 fs.insert_tree(
5428 "/root",
5429 json!({
5430 "secret.txt": "secret content",
5431 "public": {}
5432 }),
5433 )
5434 .await;
5435 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5436
5437 cx.update(|cx| {
5438 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5439 settings.tool_permissions.tools.insert(
5440 MovePathTool::NAME.into(),
5441 agent_settings::ToolRules {
5442 default: Some(settings::ToolPermissionMode::Allow),
5443 always_allow: vec![],
5444 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5445 always_confirm: vec![],
5446 invalid_patterns: vec![],
5447 },
5448 );
5449 agent_settings::AgentSettings::override_global(settings, cx);
5450 });
5451
5452 #[allow(clippy::arc_with_non_send_sync)]
5453 let tool = Arc::new(crate::MovePathTool::new(project));
5454 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5455
5456 let task = cx.update(|cx| {
5457 tool.run(
5458 ToolInput::resolved(crate::MovePathToolInput {
5459 source_path: "root/secret.txt".to_string(),
5460 destination_path: "root/public/not_secret.txt".to_string(),
5461 }),
5462 event_stream,
5463 cx,
5464 )
5465 });
5466
5467 let result = task.await;
5468 assert!(
5469 result.is_err(),
5470 "expected move to be blocked due to source path"
5471 );
5472 assert!(
5473 result.unwrap_err().contains("blocked"),
5474 "error should mention the move was blocked"
5475 );
5476}
5477
5478#[gpui::test]
5479async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5480 init_test(cx);
5481
5482 let fs = FakeFs::new(cx.executor());
5483 fs.insert_tree(
5484 "/root",
5485 json!({
5486 "confidential.txt": "confidential data",
5487 "dest": {}
5488 }),
5489 )
5490 .await;
5491 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5492
5493 cx.update(|cx| {
5494 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5495 settings.tool_permissions.tools.insert(
5496 CopyPathTool::NAME.into(),
5497 agent_settings::ToolRules {
5498 default: Some(settings::ToolPermissionMode::Allow),
5499 always_allow: vec![],
5500 always_deny: vec![
5501 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5502 ],
5503 always_confirm: vec![],
5504 invalid_patterns: vec![],
5505 },
5506 );
5507 agent_settings::AgentSettings::override_global(settings, cx);
5508 });
5509
5510 #[allow(clippy::arc_with_non_send_sync)]
5511 let tool = Arc::new(crate::CopyPathTool::new(project));
5512 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5513
5514 let task = cx.update(|cx| {
5515 tool.run(
5516 ToolInput::resolved(crate::CopyPathToolInput {
5517 source_path: "root/confidential.txt".to_string(),
5518 destination_path: "root/dest/copy.txt".to_string(),
5519 }),
5520 event_stream,
5521 cx,
5522 )
5523 });
5524
5525 let result = task.await;
5526 assert!(result.is_err(), "expected copy to be blocked");
5527 assert!(
5528 result.unwrap_err().contains("blocked"),
5529 "error should mention the copy was blocked"
5530 );
5531}
5532
5533#[gpui::test]
5534async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5535 init_test(cx);
5536
5537 let fs = FakeFs::new(cx.executor());
5538 fs.insert_tree(
5539 "/root",
5540 json!({
5541 "normal.txt": "normal content",
5542 "readonly": {
5543 "config.txt": "readonly content"
5544 }
5545 }),
5546 )
5547 .await;
5548 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5549
5550 cx.update(|cx| {
5551 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5552 settings.tool_permissions.tools.insert(
5553 SaveFileTool::NAME.into(),
5554 agent_settings::ToolRules {
5555 default: Some(settings::ToolPermissionMode::Allow),
5556 always_allow: vec![],
5557 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5558 always_confirm: vec![],
5559 invalid_patterns: vec![],
5560 },
5561 );
5562 agent_settings::AgentSettings::override_global(settings, cx);
5563 });
5564
5565 #[allow(clippy::arc_with_non_send_sync)]
5566 let tool = Arc::new(crate::SaveFileTool::new(project));
5567 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5568
5569 let task = cx.update(|cx| {
5570 tool.run(
5571 ToolInput::resolved(crate::SaveFileToolInput {
5572 paths: vec![
5573 std::path::PathBuf::from("root/normal.txt"),
5574 std::path::PathBuf::from("root/readonly/config.txt"),
5575 ],
5576 }),
5577 event_stream,
5578 cx,
5579 )
5580 });
5581
5582 let result = task.await;
5583 assert!(
5584 result.is_err(),
5585 "expected save to be blocked due to denied path"
5586 );
5587 assert!(
5588 result.unwrap_err().contains("blocked"),
5589 "error should mention the save was blocked"
5590 );
5591}
5592
5593#[gpui::test]
5594async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5595 init_test(cx);
5596
5597 let fs = FakeFs::new(cx.executor());
5598 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5599 .await;
5600 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5601
5602 cx.update(|cx| {
5603 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5604 settings.tool_permissions.tools.insert(
5605 SaveFileTool::NAME.into(),
5606 agent_settings::ToolRules {
5607 default: Some(settings::ToolPermissionMode::Allow),
5608 always_allow: vec![],
5609 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5610 always_confirm: vec![],
5611 invalid_patterns: vec![],
5612 },
5613 );
5614 agent_settings::AgentSettings::override_global(settings, cx);
5615 });
5616
5617 #[allow(clippy::arc_with_non_send_sync)]
5618 let tool = Arc::new(crate::SaveFileTool::new(project));
5619 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5620
5621 let task = cx.update(|cx| {
5622 tool.run(
5623 ToolInput::resolved(crate::SaveFileToolInput {
5624 paths: vec![std::path::PathBuf::from("root/config.secret")],
5625 }),
5626 event_stream,
5627 cx,
5628 )
5629 });
5630
5631 let result = task.await;
5632 assert!(result.is_err(), "expected save to be blocked");
5633 assert!(
5634 result.unwrap_err().contains("blocked"),
5635 "error should mention the save was blocked"
5636 );
5637}
5638
5639#[gpui::test]
5640async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5641 init_test(cx);
5642
5643 cx.update(|cx| {
5644 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5645 settings.tool_permissions.tools.insert(
5646 WebSearchTool::NAME.into(),
5647 agent_settings::ToolRules {
5648 default: Some(settings::ToolPermissionMode::Allow),
5649 always_allow: vec![],
5650 always_deny: vec![
5651 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5652 ],
5653 always_confirm: vec![],
5654 invalid_patterns: vec![],
5655 },
5656 );
5657 agent_settings::AgentSettings::override_global(settings, cx);
5658 });
5659
5660 #[allow(clippy::arc_with_non_send_sync)]
5661 let tool = Arc::new(crate::WebSearchTool);
5662 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5663
5664 let input: crate::WebSearchToolInput =
5665 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5666
5667 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5668
5669 let result = task.await;
5670 assert!(result.is_err(), "expected search to be blocked");
5671 match result.unwrap_err() {
5672 crate::WebSearchToolOutput::Error { error } => {
5673 assert!(
5674 error.contains("blocked"),
5675 "error should mention the search was blocked"
5676 );
5677 }
5678 other => panic!("expected Error variant, got: {other:?}"),
5679 }
5680}
5681
5682#[gpui::test]
5683async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5684 init_test(cx);
5685
5686 let fs = FakeFs::new(cx.executor());
5687 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5688 .await;
5689 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5690
5691 cx.update(|cx| {
5692 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5693 settings.tool_permissions.tools.insert(
5694 EditFileTool::NAME.into(),
5695 agent_settings::ToolRules {
5696 default: Some(settings::ToolPermissionMode::Confirm),
5697 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5698 always_deny: vec![],
5699 always_confirm: vec![],
5700 invalid_patterns: vec![],
5701 },
5702 );
5703 agent_settings::AgentSettings::override_global(settings, cx);
5704 });
5705
5706 let context_server_registry =
5707 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5708 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5709 let templates = crate::Templates::new();
5710 let thread = cx.new(|cx| {
5711 crate::Thread::new(
5712 project.clone(),
5713 cx.new(|_cx| prompt_store::ProjectContext::default()),
5714 context_server_registry,
5715 templates.clone(),
5716 None,
5717 cx,
5718 )
5719 });
5720
5721 #[allow(clippy::arc_with_non_send_sync)]
5722 let tool = Arc::new(crate::EditFileTool::new(
5723 project,
5724 thread.downgrade(),
5725 language_registry,
5726 templates,
5727 ));
5728 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5729
5730 let _task = cx.update(|cx| {
5731 tool.run(
5732 ToolInput::resolved(crate::EditFileToolInput {
5733 display_description: "Edit README".to_string(),
5734 path: "root/README.md".into(),
5735 mode: crate::EditFileMode::Edit,
5736 }),
5737 event_stream,
5738 cx,
5739 )
5740 });
5741
5742 cx.run_until_parked();
5743
5744 let event = rx.try_next();
5745 assert!(
5746 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5747 "expected no authorization request for allowed .md file"
5748 );
5749}
5750
5751#[gpui::test]
5752async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5753 init_test(cx);
5754
5755 let fs = FakeFs::new(cx.executor());
5756 fs.insert_tree(
5757 "/root",
5758 json!({
5759 ".zed": {
5760 "settings.json": "{}"
5761 },
5762 "README.md": "# Hello"
5763 }),
5764 )
5765 .await;
5766 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5767
5768 cx.update(|cx| {
5769 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5770 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5771 agent_settings::AgentSettings::override_global(settings, cx);
5772 });
5773
5774 let context_server_registry =
5775 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5776 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5777 let templates = crate::Templates::new();
5778 let thread = cx.new(|cx| {
5779 crate::Thread::new(
5780 project.clone(),
5781 cx.new(|_cx| prompt_store::ProjectContext::default()),
5782 context_server_registry,
5783 templates.clone(),
5784 None,
5785 cx,
5786 )
5787 });
5788
5789 #[allow(clippy::arc_with_non_send_sync)]
5790 let tool = Arc::new(crate::EditFileTool::new(
5791 project,
5792 thread.downgrade(),
5793 language_registry,
5794 templates,
5795 ));
5796
5797 // Editing a file inside .zed/ should still prompt even with global default: allow,
5798 // because local settings paths are sensitive and require confirmation regardless.
5799 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5800 let _task = cx.update(|cx| {
5801 tool.run(
5802 ToolInput::resolved(crate::EditFileToolInput {
5803 display_description: "Edit local settings".to_string(),
5804 path: "root/.zed/settings.json".into(),
5805 mode: crate::EditFileMode::Edit,
5806 }),
5807 event_stream,
5808 cx,
5809 )
5810 });
5811
5812 let _update = rx.expect_update_fields().await;
5813 let _auth = rx.expect_authorization().await;
5814}
5815
5816#[gpui::test]
5817async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5818 init_test(cx);
5819
5820 cx.update(|cx| {
5821 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5822 settings.tool_permissions.tools.insert(
5823 FetchTool::NAME.into(),
5824 agent_settings::ToolRules {
5825 default: Some(settings::ToolPermissionMode::Allow),
5826 always_allow: vec![],
5827 always_deny: vec![
5828 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5829 ],
5830 always_confirm: vec![],
5831 invalid_patterns: vec![],
5832 },
5833 );
5834 agent_settings::AgentSettings::override_global(settings, cx);
5835 });
5836
5837 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5838
5839 #[allow(clippy::arc_with_non_send_sync)]
5840 let tool = Arc::new(crate::FetchTool::new(http_client));
5841 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5842
5843 let input: crate::FetchToolInput =
5844 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5845
5846 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5847
5848 let result = task.await;
5849 assert!(result.is_err(), "expected fetch to be blocked");
5850 assert!(
5851 result.unwrap_err().contains("blocked"),
5852 "error should mention the fetch was blocked"
5853 );
5854}
5855
5856#[gpui::test]
5857async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5858 init_test(cx);
5859
5860 cx.update(|cx| {
5861 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5862 settings.tool_permissions.tools.insert(
5863 FetchTool::NAME.into(),
5864 agent_settings::ToolRules {
5865 default: Some(settings::ToolPermissionMode::Confirm),
5866 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5867 always_deny: vec![],
5868 always_confirm: vec![],
5869 invalid_patterns: vec![],
5870 },
5871 );
5872 agent_settings::AgentSettings::override_global(settings, cx);
5873 });
5874
5875 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5876
5877 #[allow(clippy::arc_with_non_send_sync)]
5878 let tool = Arc::new(crate::FetchTool::new(http_client));
5879 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5880
5881 let input: crate::FetchToolInput =
5882 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5883
5884 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5885
5886 cx.run_until_parked();
5887
5888 let event = rx.try_next();
5889 assert!(
5890 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5891 "expected no authorization request for allowed docs.rs URL"
5892 );
5893}
5894
5895#[gpui::test]
5896async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5897 init_test(cx);
5898 always_allow_tools(cx);
5899
5900 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5901 let fake_model = model.as_fake();
5902
5903 // Add a tool so we can simulate tool calls
5904 thread.update(cx, |thread, _cx| {
5905 thread.add_tool(EchoTool);
5906 });
5907
5908 // Start a turn by sending a message
5909 let mut events = thread
5910 .update(cx, |thread, cx| {
5911 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5912 })
5913 .unwrap();
5914 cx.run_until_parked();
5915
5916 // Simulate the model making a tool call
5917 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5918 LanguageModelToolUse {
5919 id: "tool_1".into(),
5920 name: "echo".into(),
5921 raw_input: r#"{"text": "hello"}"#.into(),
5922 input: json!({"text": "hello"}),
5923 is_input_complete: true,
5924 thought_signature: None,
5925 },
5926 ));
5927 fake_model
5928 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5929
5930 // Signal that a message is queued before ending the stream
5931 thread.update(cx, |thread, _cx| {
5932 thread.set_has_queued_message(true);
5933 });
5934
5935 // Now end the stream - tool will run, and the boundary check should see the queue
5936 fake_model.end_last_completion_stream();
5937
5938 // Collect all events until the turn stops
5939 let all_events = collect_events_until_stop(&mut events, cx).await;
5940
5941 // Verify we received the tool call event
5942 let tool_call_ids: Vec<_> = all_events
5943 .iter()
5944 .filter_map(|e| match e {
5945 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5946 _ => None,
5947 })
5948 .collect();
5949 assert_eq!(
5950 tool_call_ids,
5951 vec!["tool_1"],
5952 "Should have received a tool call event for our echo tool"
5953 );
5954
5955 // The turn should have stopped with EndTurn
5956 let stop_reasons = stop_events(all_events);
5957 assert_eq!(
5958 stop_reasons,
5959 vec![acp::StopReason::EndTurn],
5960 "Turn should have ended after tool completion due to queued message"
5961 );
5962
5963 // Verify the queued message flag is still set
5964 thread.update(cx, |thread, _cx| {
5965 assert!(
5966 thread.has_queued_message(),
5967 "Should still have queued message flag set"
5968 );
5969 });
5970
5971 // Thread should be idle now
5972 thread.update(cx, |thread, _cx| {
5973 assert!(
5974 thread.is_turn_complete(),
5975 "Thread should not be running after turn ends"
5976 );
5977 });
5978}