1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, AtomicUsize, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61pub(crate) fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// Scriptable in-memory stand-in for a terminal, used by the terminal-tool tests.
pub(crate) struct FakeTerminalHandle {
    // Set to true once `kill` is invoked (see the `TerminalHandle` impl below).
    killed: Arc<AtomicBool>,
    // Backing flag for `was_stopped_by_user`; toggled via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    // One-shot sender consumed by `signal_exit` to resolve `wait_for_exit`.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Shared task that completes when the fake terminal "exits".
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from `current_output`.
    output: acp::TerminalOutputResponse,
    // Stable id returned from `id`.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 pub(crate) fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 pub(crate) fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 pub(crate) fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 pub(crate) fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 pub(crate) fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
impl crate::TerminalHandle for FakeTerminalHandle {
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    // Returns the canned output configured at construction time.
    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    // Records the kill and resolves `wait_for_exit`, so callers awaiting the
    // exit task unblock just like they would for a real killed process.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        self.signal_exit();
        Ok(())
    }

    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
        Ok(self.stopped_by_user.load(Ordering::SeqCst))
    }
}
159
/// Minimal subagent fake: every `send` resolves with the output of a
/// preconfigured shared task.
struct FakeSubagentHandle {
    // Session id reported via `SubagentHandle::id`.
    session_id: acp::SessionId,
    // Shared task whose result is returned from every `send` call.
    send_task: Shared<Task<String>>,
}
164
impl SubagentHandle for FakeSubagentHandle {
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    // Not exercised by these tests.
    fn num_entries(&self, _cx: &App) -> usize {
        unimplemented!()
    }

    // Ignores the incoming message and resolves with the preconfigured task's
    // output on the background executor.
    fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let task = self.send_task.clone();
        cx.background_spawn(async move { Ok(task.await) })
    }
}
179
/// Test `ThreadEnvironment` that hands out preconfigured fake handles.
#[derive(Default)]
pub(crate) struct FakeThreadEnvironment {
    // Handle returned from every `create_terminal` call; `create_terminal`
    // panics if this is unset.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    // Handle returned from every `create_subagent` call; `create_subagent`
    // panics if this is unset.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
    // Counts `create_terminal` invocations (see `terminal_creation_count`).
    terminal_creations: Arc<AtomicUsize>,
}
186
187impl FakeThreadEnvironment {
188 pub(crate) fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
189 Self {
190 terminal_handle: Some(terminal_handle.into()),
191 ..self
192 }
193 }
194
195 pub(crate) fn terminal_creation_count(&self) -> usize {
196 self.terminal_creations.load(Ordering::SeqCst)
197 }
198}
199
200impl crate::ThreadEnvironment for FakeThreadEnvironment {
201 fn create_terminal(
202 &self,
203 _command: String,
204 _cwd: Option<std::path::PathBuf>,
205 _output_byte_limit: Option<u64>,
206 _cx: &mut AsyncApp,
207 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
208 self.terminal_creations.fetch_add(1, Ordering::SeqCst);
209 let handle = self
210 .terminal_handle
211 .clone()
212 .expect("Terminal handle not available on FakeThreadEnvironment");
213 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
214 }
215
216 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
217 Ok(self
218 .subagent_handle
219 .clone()
220 .expect("Subagent handle not available on FakeThreadEnvironment")
221 as Rc<dyn SubagentHandle>)
222 }
223}
224
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    // Every handle ever created, in creation order, so tests can inspect them
    // after the fact. RefCell because `create_terminal` only gets `&self`.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}

impl MultiTerminalEnvironment {
    fn new() -> Self {
        Self {
            handles: std::cell::RefCell::new(Vec::new()),
        }
    }

    // Snapshot of all handles created so far (cheap `Rc` clones).
    fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
        self.handles.borrow().clone()
    }
}
241
impl crate::ThreadEnvironment for MultiTerminalEnvironment {
    // Creates a fresh never-exiting fake terminal per call and records it so
    // tests can later inspect every handle via `handles()`.
    fn create_terminal(
        &self,
        _command: String,
        _cwd: Option<std::path::PathBuf>,
        _output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        self.handles.borrow_mut().push(handle.clone());
        Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
    }

    // Subagents are not needed by the multi-terminal tests.
    fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
        unimplemented!()
    }
}
259
/// Overrides the global agent settings so every tool runs without prompting
/// for permission — used by tests that aren't about the authorization flow.
fn always_allow_tools(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
        settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
        agent_settings::AgentSettings::override_global(settings, cx);
    });
}
267
268#[gpui::test]
269async fn test_echo(cx: &mut TestAppContext) {
270 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
271 let fake_model = model.as_fake();
272
273 let events = thread
274 .update(cx, |thread, cx| {
275 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
276 })
277 .unwrap();
278 cx.run_until_parked();
279 fake_model.send_last_completion_stream_text_chunk("Hello");
280 fake_model
281 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
282 fake_model.end_last_completion_stream();
283
284 let events = events.collect().await;
285 thread.update(cx, |thread, _cx| {
286 assert_eq!(
287 thread.last_received_or_pending_message().unwrap().role(),
288 Role::Assistant
289 );
290 assert_eq!(
291 thread
292 .last_received_or_pending_message()
293 .unwrap()
294 .to_markdown(),
295 "Hello\n"
296 )
297 });
298 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
299}
300
/// When a terminal tool invocation specifies `timeout_ms` and the command
/// never exits, the timeout must kill the terminal handle and the tool result
/// must still include the partial output captured so far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own, so only the timeout can end it.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Tiny 5ms timeout against a command that would "sleep" forever.
    let task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            }),
            event_stream,
            cx,
        )
    });

    // The tool should first surface the terminal in a tool-call update.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked + a 1ms timer per
    // iteration) until it resolves, bailing out after a 500ms wall-clock
    // deadline so a regression can't hang the suite.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
366
/// Counterpart to the timeout test: with `timeout_ms: None` the terminal
/// handle must NOT be killed, even after giving the executor time to run.
/// NOTE(review): currently `#[ignore]`d — presumably flaky or slow; confirm
/// before re-enabling.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Keep the task alive (but unawaited) — the command never exits, so the
    // tool call never completes; we only watch for side effects on the handle.
    let _task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            }),
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give any (erroneous) kill path a chance to run before asserting.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
416
/// A `Thinking` completion event should be rendered in the assistant message
/// as a `<think>…</think>` section preceding the final text.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Thinking chunk first, then the visible answer, then a clean stop.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message().unwrap().role(),
            Role::Assistant
        );
        assert_eq!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
465
466#[gpui::test]
467async fn test_system_prompt(cx: &mut TestAppContext) {
468 let ThreadTest {
469 model,
470 thread,
471 project_context,
472 ..
473 } = setup(cx, TestModel::Fake).await;
474 let fake_model = model.as_fake();
475
476 project_context.update(cx, |project_context, _cx| {
477 project_context.shell = "test-shell".into()
478 });
479 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
480 thread
481 .update(cx, |thread, cx| {
482 thread.send(UserMessageId::new(), ["abc"], cx)
483 })
484 .unwrap();
485 cx.run_until_parked();
486 let mut pending_completions = fake_model.pending_completions();
487 assert_eq!(
488 pending_completions.len(),
489 1,
490 "unexpected pending completions: {:?}",
491 pending_completions
492 );
493
494 let pending_completion = pending_completions.pop().unwrap();
495 assert_eq!(pending_completion.messages[0].role, Role::System);
496
497 let system_message = &pending_completion.messages[0];
498 let system_prompt = system_message.content[0].to_str().unwrap();
499 assert!(
500 system_prompt.contains("test-shell"),
501 "unexpected system message: {:?}",
502 system_message
503 );
504 assert!(
505 system_prompt.contains("## Fixing Diagnostics"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509}
510
511#[gpui::test]
512async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
513 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
514 let fake_model = model.as_fake();
515
516 thread
517 .update(cx, |thread, cx| {
518 thread.send(UserMessageId::new(), ["abc"], cx)
519 })
520 .unwrap();
521 cx.run_until_parked();
522 let mut pending_completions = fake_model.pending_completions();
523 assert_eq!(
524 pending_completions.len(),
525 1,
526 "unexpected pending completions: {:?}",
527 pending_completions
528 );
529
530 let pending_completion = pending_completions.pop().unwrap();
531 assert_eq!(pending_completion.messages[0].role, Role::System);
532
533 let system_message = &pending_completion.messages[0];
534 let system_prompt = system_message.content[0].to_str().unwrap();
535 assert!(
536 !system_prompt.contains("## Tool Use"),
537 "unexpected system message: {:?}",
538 system_message
539 );
540 assert!(
541 !system_prompt.contains("## Fixing Diagnostics"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545}
546
/// Verifies prompt-cache markers: in every request, only the most recent user
/// message (or tool result) carries `cache: true`; all earlier messages are
/// sent with `cache: false` so providers can reuse their prompt cache.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, hence the `[1..]` slice here and below.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (carrying the tool result) should mark only that
    // final tool-result message as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
692
/// End-to-end test (real Sonnet 4 model, only runs with the `e2e` feature):
/// exercises a fast tool and a slow tool and checks both turns end cleanly.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should relay the delay tool's "Ding" output in its reply.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
752
/// End-to-end test (real Sonnet 4 model, only runs with the `e2e` feature):
/// verifies tool-use input streams incrementally — at least one observed tool
/// call must have incomplete input, and completed calls must have all keys.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_received_or_pending_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be mid-stream: the
                        // later keys (e.g. "g") haven't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
804
/// Exercises the tool-permission flow: allowing once, denying, choosing
/// "always allow" for a tool, and verifying that subsequent calls to an
/// always-allowed tool run without any authorization prompt.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two tool uses in one turn -> two independent authorization requests.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries one success and one error tool result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
943
944#[gpui::test]
945async fn test_tool_hallucination(cx: &mut TestAppContext) {
946 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
947 let fake_model = model.as_fake();
948
949 let mut events = thread
950 .update(cx, |thread, cx| {
951 thread.send(UserMessageId::new(), ["abc"], cx)
952 })
953 .unwrap();
954 cx.run_until_parked();
955 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
956 LanguageModelToolUse {
957 id: "tool_id_1".into(),
958 name: "nonexistent_tool".into(),
959 raw_input: "{}".into(),
960 input: json!({}),
961 is_input_complete: true,
962 thought_signature: None,
963 },
964 ));
965 fake_model.end_last_completion_stream();
966
967 let tool_call = expect_tool_call(&mut events).await;
968 assert_eq!(tool_call.title, "nonexistent_tool");
969 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
970 let update = expect_tool_call_update_fields(&mut events).await;
971 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
972}
973
974async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
975 let event = events
976 .next()
977 .await
978 .expect("no tool call authorization event received")
979 .unwrap();
980 match event {
981 ThreadEvent::ToolCall(tool_call) => tool_call,
982 event => {
983 panic!("Unexpected event {event:?}");
984 }
985 }
986}
987
988async fn expect_tool_call_update_fields(
989 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
990) -> acp::ToolCallUpdate {
991 let event = events
992 .next()
993 .await
994 .expect("no tool call authorization event received")
995 .unwrap();
996 match event {
997 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
998 event => {
999 panic!("Unexpected event {event:?}");
1000 }
1001 }
1002}
1003
1004async fn next_tool_call_authorization(
1005 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1006) -> ToolCallAuthorization {
1007 loop {
1008 let event = events
1009 .next()
1010 .await
1011 .expect("no tool call authorization event received")
1012 .unwrap();
1013 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1014 let permission_kinds = tool_call_authorization
1015 .options
1016 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1017 .map(|option| option.kind);
1018 let allow_once = tool_call_authorization
1019 .options
1020 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1021 .map(|option| option.kind);
1022
1023 assert_eq!(
1024 permission_kinds,
1025 Some(acp::PermissionOptionKind::AllowAlways)
1026 );
1027 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1028 return tool_call_authorization;
1029 }
1030 }
1031}
1032
/// A multi-token terminal command should yield three dropdown choices,
/// including an "Always for `<command> <subcommand>` commands" pattern entry.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let permission_options = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let has_label =
        |label: &str| choices.iter().any(|choice| choice.allow.name.as_ref() == label);
    assert!(has_label("Always for terminal"));
    assert!(has_label("Always for `cargo build` commands"));
    assert!(has_label("Only this time"));
}
1054
/// When the second token is a flag (`-la`), the pattern entry should cover
/// only the command name itself.
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let permission_options =
        ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()])
            .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let has_label =
        |label: &str| choices.iter().any(|choice| choice.allow.name.as_ref() == label);
    assert!(has_label("Always for terminal"));
    assert!(has_label("Always for `ls` commands"));
    assert!(has_label("Only this time"));
}
1074
#[test]
fn test_permission_options_terminal_single_word_command() {
    // A single-word command still yields a per-command pattern choice.
    let built = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()])
        .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    for expected in [
        "Always for terminal",
        "Always for `whoami` commands",
        "Only this time",
    ] {
        assert!(
            choices
                .iter()
                .any(|choice| choice.allow.name.as_ref() == expected)
        );
    }
}
1094
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing a file nested in a directory should offer an "always for this
    // directory prefix" choice alongside the tool-wide one.
    let built = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()])
        .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let has_label = |label: &str| {
        choices
            .iter()
            .any(|choice| choice.allow.name.as_ref() == label)
    };
    assert!(has_label("Always for edit file"));
    assert!(has_label("Always for `src/`"));
}
1112
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL should offer an "always for this domain" choice derived
    // from the URL's host, alongside the tool-wide one.
    let built = ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()])
        .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let has_label = |label: &str| {
        choices
            .iter()
            .any(|choice| choice.allow.name.as_ref() == label)
    };
    assert!(has_label("Always for fetch"));
    assert!(has_label("Always for `docs.rs`"));
}
1130
#[test]
fn test_permission_options_without_pattern() {
    // A command whose first token is a script path (`./deploy.sh`) produces
    // no per-command pattern: only tool-wide and one-time choices remain.
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert_eq!(choices.len(), 2);
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Only this time"));
    assert!(!labels.iter().any(|label| label.contains("commands")));
}
1152
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization must not offer any "always" choices:
    // only a flat allow-once / reject-once pair.
    let built =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()])
            .build_permission_options();

    let options = match built {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };

    assert_eq!(options.len(), 2);
    let has = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has("deny", acp::PermissionOptionKind::RejectOnce));
}
1173
#[test]
fn test_permission_option_ids_for_terminal() {
    // Each dropdown choice carries stable machine-readable option ids on
    // both its allow and deny sides; pattern-scoped ids embed the pattern
    // after a newline separator.
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let allow_ids = choices
        .iter()
        .map(|choice| choice.allow.option_id.0.to_string())
        .collect::<Vec<String>>();
    let deny_ids = choices
        .iter()
        .map(|choice| choice.deny.option_id.0.to_string())
        .collect::<Vec<String>>();

    assert!(allow_ids.contains(&"always_allow:terminal".to_string()));
    assert!(allow_ids.contains(&"allow".to_string()));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.contains(&"always_deny:terminal".to_string()));
    assert!(deny_ids.contains(&"deny".to_string()));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1213
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the `e2e` feature, against a real
    // Sonnet 4 model): two overlapping DelayTool calls in one turn must both
    // complete, and the model's final text must reflect their output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    // A successful turn ends with a single EndTurn stop event.
    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message.
        let last_message = thread.last_received_or_pending_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        // "Ding" is presumably the DelayTool's completion output that the
        // model is expected to echo — confirm against test_tools.
        assert!(text.contains("Ding"));
    });
}
1258
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching the agent profile changes which tools are
    // offered to the model on the next completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register three native tools; the profiles below enable subsets of them.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::NAME: true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the pending completion request to see which tools were sent.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    // Only the profile-enabled tools appear (delay and echo, in this order).
    assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::NAME]);
}
1338
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that MCP (context-server) tools are exposed to the model,
    // that a name collision with a native tool is resolved by prefixing the
    // MCP tool with its server name (`test_server_echo`), and that tool
    // results are attributed to the correct tool names.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "tool_permissions": { "default": "allow" },
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register an MCP server exposing a tool named "echo" — same name as the
    // native EchoTool that will be added later in this test.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should receive the call and respond through the channel.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the first request popped above,
    // so this re-asserts the same value rather than inspecting the follow-up
    // completion request — confirm the duplication is intentional.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Native tool keeps the bare name; the MCP tool gets the server prefix.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the prefixed call reaches the MCP server; it still sees its
    // original (unprefixed) tool name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1504
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    // Regression test: a saved MCP tool result must still be emitted when a
    // thread is replayed while its MCP server is disconnected (e.g. right
    // after app restart, before servers have reconnected).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    context_server_store.update(cx, |store, cx| {
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    // Scan the replayed event stream for the tool call and for an update
    // carrying the saved raw output.
    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1686
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies the tool-name disambiguation/truncation scheme when several
    // MCP servers and native tools collide: conflicting names get server
    // prefixes (full or single-letter, per the assertions below), and names
    // at or over MAX_TOOL_NAME_LENGTH are truncated to fit.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names that also collide with server "zzz" below.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One name over the limit, with no collision.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request lists the final, disambiguated tool names:
    // colliding near-limit names get a single-letter server prefix ("y_",
    // "z_") and are truncated; the over-limit "c…" name is truncated in place.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1870
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test: cancelling a turn while a tool is still running must
    // close the event stream with a Cancelled stop reason, and the thread
    // must remain usable for a subsequent message.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Track when the echo tool (the finite one) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1956
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // When a running terminal tool is cancelled, the tool result must still
    // include the terminal's partial output (not just a generic "canceled"
    // message), and the underlying terminal must be killed.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal whose command never exits, so cancellation is the only
    // way the tool can finish.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the terminal tool use recorded in the agent message.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2053
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a shared flag the tool sets when it observes the
    // cancellation signal.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported
    let mut tool_started = false;
    // NOTE(review): the iteration budget scales with CPU count as a crude
    // timeout heuristic — confirm this is intentional rather than a misuse
    // of `num_cpus()`.
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // Drain any events that are already ready without blocking.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2148
2149/// Helper to verify thread can recover after cancellation by sending a simple message.
2150async fn verify_thread_recovery(
2151 thread: &Entity<Thread>,
2152 fake_model: &FakeLanguageModel,
2153 cx: &mut TestAppContext,
2154) {
2155 let events = thread
2156 .update(cx, |thread, cx| {
2157 thread.send(
2158 UserMessageId::new(),
2159 ["Testing: reply with 'Hello' then stop."],
2160 cx,
2161 )
2162 })
2163 .unwrap();
2164 cx.run_until_parked();
2165 fake_model.send_last_completion_stream_text_chunk("Hello");
2166 fake_model
2167 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2168 fake_model.end_last_completion_stream();
2169
2170 let events = events.collect::<Vec<_>>().await;
2171 thread.update(cx, |thread, _cx| {
2172 let message = thread.last_received_or_pending_message().unwrap();
2173 let agent_message = message.as_agent_message().unwrap();
2174 assert_eq!(
2175 agent_message.content,
2176 vec![AgentMessageContent::Text("Hello".to_string())]
2177 );
2178 });
2179 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2180}
2181
2182/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2183async fn wait_for_terminal_tool_started(
2184 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2185 cx: &mut TestAppContext,
2186) {
2187 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2188 for _ in 0..deadline {
2189 cx.run_until_parked();
2190
2191 while let Some(Some(event)) = events.next().now_or_never() {
2192 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2193 update,
2194 ))) = &event
2195 {
2196 if update.fields.content.as_ref().is_some_and(|content| {
2197 content
2198 .iter()
2199 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2200 }) {
2201 return;
2202 }
2203 }
2204 }
2205
2206 cx.background_executor
2207 .timer(Duration::from_millis(10))
2208 .await;
2209 }
2210 panic!("terminal tool did not start within the expected time");
2211}
2212
2213/// Collects events until a Stop event is received, driving the executor to completion.
2214async fn collect_events_until_stop(
2215 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2216 cx: &mut TestAppContext,
2217) -> Vec<Result<ThreadEvent>> {
2218 let mut collected = Vec::new();
2219 let deadline = cx.executor().num_cpus() * 200;
2220
2221 for _ in 0..deadline {
2222 cx.executor().advance_clock(Duration::from_millis(10));
2223 cx.run_until_parked();
2224
2225 while let Some(Some(event)) = events.next().now_or_never() {
2226 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2227 collected.push(event);
2228 if is_stop {
2229 return collected;
2230 }
2231 }
2232 }
2233 panic!(
2234 "did not receive Stop event within the expected time; collected {} events",
2235 collected.len()
2236 );
2237}
2238
/// Truncating the thread while a terminal tool is still running must kill the
/// terminal, leave the thread empty, and leave the thread usable afterwards.
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake terminal whose command never exits on its own, so it is guaranteed
    // to still be running when we truncate.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation (we truncated at the only
    // user message, so nothing should remain).
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2305
/// Tests that cancellation properly kills all running terminal tools when
/// multiple are active: both fake terminals must be killed and the turn must
/// end with `Cancelled`.
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment that creates a fresh never-exiting terminal handle per command.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates.
    // The inner `break` only exits the drain loop; the outer check below exits
    // the polling loop once both terminals have been seen.
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2414
/// Stopping a command from the terminal card's own stop button (rather than
/// cancelling the whole thread) should end the turn normally (`EndTurn`) while
/// the tool result reports that the user stopped the command.
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Never-exiting terminal so the command is still running when we "stop" it.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the tool use recorded in the agent message...
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // ...and its corresponding result.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2509
/// A terminal command with a configured `timeout_ms` that expires should have
/// its terminal killed, end the turn normally, and produce a tool result that
/// mentions a timeout — not a user stop.
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Never-exiting terminal so only the timeout can end the command.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (100ms configured above)
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the tool use recorded in the agent message...
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // ...and its corresponding result.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2608
2609#[gpui::test]
2610async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2611 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2612 let fake_model = model.as_fake();
2613
2614 let events_1 = thread
2615 .update(cx, |thread, cx| {
2616 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2617 })
2618 .unwrap();
2619 cx.run_until_parked();
2620 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2621 cx.run_until_parked();
2622
2623 let events_2 = thread
2624 .update(cx, |thread, cx| {
2625 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2626 })
2627 .unwrap();
2628 cx.run_until_parked();
2629 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2630 fake_model
2631 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2632 fake_model.end_last_completion_stream();
2633
2634 let events_1 = events_1.collect::<Vec<_>>().await;
2635 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2636 let events_2 = events_2.collect::<Vec<_>>().await;
2637 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2638}
2639
/// Regression test: a turn waiting out a retry delay (after a retryable error)
/// must exit immediately when the user switches models and sends a new
/// message, instead of later retrying against the stale model.
#[gpui::test]
async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
    // Regression test: when a completion fails with a retryable error (e.g. upstream 500),
    // the retry loop waits on a timer. If the user switches models and sends a new message
    // during that delay, the old turn should exit immediately instead of retrying with the
    // stale model.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let model_a = model.as_fake();

    // Start a turn with model_a.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    assert_eq!(model_a.completion_count(), 1);

    // Model returns a retryable upstream 500. The turn enters the retry delay.
    model_a.send_last_completion_stream_error(
        LanguageModelCompletionError::UpstreamProviderError {
            message: "Internal server error".to_string(),
            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
            retry_after: None,
        },
    );
    model_a.end_last_completion_stream();
    cx.run_until_parked();

    // The old completion was consumed; model_a has no pending requests yet because the
    // retry timer hasn't fired.
    assert_eq!(model_a.completion_count(), 0);

    // Switch to model_b and send a new message. This cancels the old turn.
    let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake", "model-b", "Model B", false,
    ));
    thread.update(cx, |thread, cx| {
        thread.set_model(model_b.clone(), cx);
    });
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Continue"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // model_b should have received its completion request.
    assert_eq!(model_b.as_fake().completion_count(), 1);

    // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
    cx.executor().advance_clock(Duration::from_secs(10));
    cx.run_until_parked();

    // model_a must NOT have received another completion request — the cancelled turn
    // should have exited during the retry delay rather than retrying with the old model.
    assert_eq!(
        model_a.completion_count(),
        0,
        "old model should not receive a retry request after cancellation"
    );

    // Complete model_b's turn.
    model_b
        .as_fake()
        .send_last_completion_stream_text_chunk("Done!");
    model_b
        .as_fake()
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    model_b.as_fake().end_last_completion_stream();

    // Old turn reports Cancelled; new turn completes normally.
    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);

    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2717
2718#[gpui::test]
2719async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2720 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2721 let fake_model = model.as_fake();
2722
2723 let events_1 = thread
2724 .update(cx, |thread, cx| {
2725 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2726 })
2727 .unwrap();
2728 cx.run_until_parked();
2729 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2730 fake_model
2731 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2732 fake_model.end_last_completion_stream();
2733 let events_1 = events_1.collect::<Vec<_>>().await;
2734
2735 let events_2 = thread
2736 .update(cx, |thread, cx| {
2737 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2738 })
2739 .unwrap();
2740 cx.run_until_parked();
2741 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2742 fake_model
2743 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2744 fake_model.end_last_completion_stream();
2745 let events_2 = events_2.collect::<Vec<_>>().await;
2746
2747 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2748 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2749}
2750
/// When the model refuses (StopReason::Refusal), the thread should roll back
/// all messages after — and including — the last user message, leaving the
/// thread empty in this single-message case.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model responds, only the user message is present.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2799
/// Truncating at the first (only) message should empty the thread and clear
/// token usage; the thread must then accept a brand-new turn with fresh usage.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // No usage yet: the model hasn't reported any tokens.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update so there is state to clear on truncate.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncating at the first message removes everything, including usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // The new turn's usage replaces the cleared usage from before truncation.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2921
/// Truncating at the second message should restore the thread exactly to its
/// state after the first exchange — markdown content and token usage included.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: user message + assistant reply + usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion for the post-first-exchange state; reused after the
    // truncate to verify the thread rolled back to exactly this state.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                    max_output_tokens: None,
                    input_tokens: 32_000,
                    output_tokens: 16_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, which will later be truncated away.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });

    // Truncate at the second message and verify we're back to the first state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
3036
/// The summarization model should generate a title from the first exchange
/// (taking only the first line of its output), and should NOT be invoked
/// again for subsequent messages once a title exists.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model dedicated to summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Only the first line ("Hello world") should be used as the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
3082
/// When building a completion request, tool calls that are still pending
/// (e.g. awaiting user permission) must be omitted, while completed tool
/// calls appear with their tool results.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // One tool that blocks on permission (stays pending) and one that runs
    // immediately (completes).
    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::NAME.into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system/context message; compare from index 1 onward.
    // The pending permission tool must be absent; only the echo tool use and
    // its result are included.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
3162
// End-to-end exercise of `NativeAgentConnection`: session creation, model
// listing and selection, prompting through the ACP thread, cancellation, and
// graceful failure after the session has been closed.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(user_store.clone(), client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    // Fail fast if anything in this test would block the executor forever.
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = PathList::new(&[Path::new("/test")]);
    let thread_store = cx.new(|cx| ThreadStore::new(cx));

    // Create agent and connection
    let agent = cx
        .update(|cx| NativeAgent::new(thread_store, templates.clone(), None, fake_fs.clone(), cx));
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_session(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes a single fake provider/model pair.
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream back a fake response.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Explicitly close the session and drop the ACP thread.
    cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
        .await
        .unwrap();
    drop(acp_thread);
    // Prompting a closed session must surface a "Session not found" error
    // rather than panicking or hanging.
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
3285
// Verifies the ACP event sequence produced while a tool call streams in:
// initial ToolCall on the first partial input, an update when input completes,
// then InProgress and Completed status updates as the tool runs.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Echo something"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: EchoTool::NAME.into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "text": "Hello!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "echo".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the tool call is surfaced with its (still empty) raw input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Echo")
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
    );
    // Second event: the completed input is reflected in an update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Echo")
                .kind(acp::ToolKind::Other)
                .raw_input(json!({ "text": "Hello!"}))
        )
    );
    // Third event: the tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    // Final event: completion with the echoed output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Hello!")
        )
    );
}
3364
3365#[gpui::test]
3366async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3367 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3368 let fake_model = model.as_fake();
3369
3370 let mut events = thread
3371 .update(cx, |thread, cx| {
3372 thread.send(UserMessageId::new(), ["Hello!"], cx)
3373 })
3374 .unwrap();
3375 cx.run_until_parked();
3376
3377 fake_model.send_last_completion_stream_text_chunk("Hey!");
3378 fake_model.end_last_completion_stream();
3379
3380 let mut retry_events = Vec::new();
3381 while let Some(Ok(event)) = events.next().await {
3382 match event {
3383 ThreadEvent::Retry(retry_status) => {
3384 retry_events.push(retry_status);
3385 }
3386 ThreadEvent::Stop(..) => break,
3387 _ => {}
3388 }
3389 }
3390
3391 assert_eq!(retry_events.len(), 0);
3392 thread.read_with(cx, |thread, _cx| {
3393 assert_eq!(
3394 thread.to_markdown(),
3395 indoc! {"
3396 ## User
3397
3398 Hello!
3399
3400 ## Assistant
3401
3402 Hey!
3403 "}
3404 )
3405 });
3406}
3407
// Verifies that a retryable stream error (server overloaded) triggers exactly
// one Retry event, and that the resumed turn continues in a fresh assistant
// message separated by a `[resume]` marker in the transcript.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so the thread issues a new request.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3471
// Verifies that when a stream fails after the model has issued a tool call,
// the tool still runs to completion and its result is included in the retried
// request, keeping the tool_use/tool_result pairing intact.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model calls the echo tool, then the stream dies with a retryable
    // overload error before producing any text.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the next pending completion is the retried
    // request. messages[0] is the system prompt, so compare from index 1.
    // It must contain the tool use AND its result, proving the tool call
    // was finished while the stream was down.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3551
// Verifies that after MAX_RETRY_ATTEMPTS consecutive retryable failures the
// thread gives up: it emits one Retry event per attempt with increasing
// attempt numbers, then surfaces the final error to the caller.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial request plus every retry (hence MAX + 1 iterations),
    // advancing the clock past the retry delay each time.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single stream error of the same
    // kind that triggered the retries.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
3605
// Regression test: if the LLM stream errors out while a streaming tool is
// still waiting for its final input, the tool must terminate with an error
// result (and the turn must retry) instead of deadlocking.
#[gpui::test]
async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
    cx: &mut TestAppContext,
) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(StreamingEchoTool::new());
    });

    let _events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the streaming_echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Send a partial tool use (is_input_complete = false), simulating the LLM
    // streaming input for a tool.
    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: "streaming_echo".into(),
        raw_input: r#"{"text": "partial"}"#.into(),
        input: json!({"text": "partial"}),
        is_input_complete: false,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    cx.run_until_parked();

    // Send a stream error WITHOUT ever sending is_input_complete = true.
    // Before the fix, this would deadlock: the tool waits for more partials
    // (or cancellation), run_turn_internal waits for the tool, and the sender
    // keeping the channel open lives inside RunningTurn.
    fake_model.send_last_completion_stream_error(
        LanguageModelCompletionError::UpstreamProviderError {
            message: "Internal server error".to_string(),
            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
            retry_after: None,
        },
    );
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so run_turn_internal retries.
    cx.executor().advance_clock(Duration::from_secs(5));
    cx.run_until_parked();

    // The retry request should contain the streaming tool's error result,
    // proving the tool terminated and its result was forwarded.
    // (messages[0] is the system prompt, so compare from index 1.)
    let completion = fake_model
        .pending_completions()
        .pop()
        .expect("No running turn");
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the streaming_echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_use.name,
                        is_error: true,
                        content: "Failed to receive tool input: tool input was not fully received"
                            .into(),
                        output: Some(
                            "Failed to receive tool input: tool input was not fully received"
                                .into()
                        ),
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retry round so the turn completes cleanly.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Thread should not be stuck; the turn should have completed",
        );
    });
}
3712
3713/// Filters out the stop events for asserting against in tests
3714fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3715 result_events
3716 .into_iter()
3717 .filter_map(|event| match event.unwrap() {
3718 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3719 _ => None,
3720 })
3721 .collect()
3722}
3723
/// Bundle of entities produced by `setup` that tests use to drive a `Thread`.
struct ThreadTest {
    /// The model the thread talks to; for `TestModel::Fake` this can be
    /// downcast via `as_fake()` to drive completions manually.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread.
    project_context: Entity<ProjectContext>,
    /// Store used to register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
3731
/// Which language model a test should run against.
enum TestModel {
    /// A real Anthropic model; requires authentication, used for live tests.
    Sonnet4,
    /// The in-process `FakeLanguageModel`, driven manually by the test.
    Fake,
}
3736
impl TestModel {
    /// Returns the registry id used to look this model up.
    ///
    /// `Fake` is unreachable here because `setup` constructs the fake model
    /// directly instead of resolving it through the registry.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
3745
/// Builds the common test environment: a fake filesystem with a settings file
/// enabling all test tools, an initialized settings/model stack, a test
/// project, and a `Thread` wired to the requested model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file whose profile enables every test tool, so
    // threads created below may call any of them without extra setup.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            StreamingEchoTool::NAME: true,
                            StreamingFailingEchoTool::NAME: true,
                            TerminalTool::NAME: true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider initialization.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(user_store.clone(), client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the settings file above.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly; real models
    // are looked up in the registry and authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3850
// Test-binary constructor: enable env_logger output, but only when the
// developer opted in by setting RUST_LOG.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
3858
3859fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3860 let fs = fs.clone();
3861 cx.spawn({
3862 async move |cx| {
3863 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3864 cx.background_executor(),
3865 fs,
3866 paths::settings_file().clone(),
3867 );
3868 let _watcher_task = watcher_task;
3869
3870 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3871 cx.update(|cx| {
3872 SettingsStore::update_global(cx, |settings, cx| {
3873 settings.set_user_settings(&new_settings_content, cx)
3874 })
3875 })
3876 .ok();
3877 }
3878 }
3879 })
3880 .detach();
3881}
3882
3883fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3884 completion
3885 .tools
3886 .iter()
3887 .map(|tool| tool.name.clone())
3888 .collect()
3889}
3890
/// Registers and starts a fake MCP context server named `name` that
/// advertises `tools`. Returns a receiver yielding each incoming tool call
/// together with a oneshot sender the test uses to provide the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is never executed; the fake transport below intercepts IO.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport answering the minimal MCP handshake: Initialize,
    // ListTools, and CallTool (which is forwarded to the test via channel).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake complete before returning the call receiver.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3970
// Verifies `tokens_before_message`: each user message reports the input-token
// count from the request that preceded it (None for the first message), and
// earlier messages keep their original counts as the conversation grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
4077
// Verifies that truncating the thread at a message removes that message's
// token bookkeeping: lookups for the truncated message return None while
// surviving messages are unaffected.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
4148
/// Exercises the terminal tool's permission-rule evaluation end to end:
/// 1. a deny pattern blocks a matching command,
/// 2. an allow pattern skips confirmation and overrides a tool-level deny default,
/// 3. a confirm pattern still prompts even under a global allow default,
/// 4. a tool-specific deny default wins over a global allow default.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        // The terminal never exits on its own; the deny rule must fail the
        // call before the command would ever run.
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Confirm),
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default: Deny)
    {
        // Immediate exit with status 0 lets the command "succeed" once the
        // allow rule lets it through.
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // Terminal content arriving without any authorization request is the
        // signal that the allow rule bypassed confirmation.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: global default: allow does NOT override always_confirm patterns
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Allow),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let _task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // With global default: allow, confirm patterns are still respected
        // The expect_authorization() call will panic if no authorization is requested,
        // which validates that the confirm pattern still triggers confirmation
        let _auth = rx.expect_authorization().await;

        // The task never completes (we never answer the authorization), so
        // drop it explicitly rather than awaiting.
        drop(_task);
    }

    // Test 4: tool-specific default: deny is respected even with global default: allow
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // tool-specific default: deny is respected even with global default: allow
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by tool-specific deny default"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("disabled"),
            "error should mention the tool is disabled, got: {err_msg}"
        );
    }
}
4366
4367#[gpui::test]
4368async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4369 init_test(cx);
4370 cx.update(|cx| {
4371 LanguageModelRegistry::test(cx);
4372 });
4373 cx.update(|cx| {
4374 cx.update_flags(true, vec!["subagents".to_string()]);
4375 });
4376
4377 let fs = FakeFs::new(cx.executor());
4378 fs.insert_tree(
4379 "/",
4380 json!({
4381 "a": {
4382 "b.md": "Lorem"
4383 }
4384 }),
4385 )
4386 .await;
4387 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4388 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4389 let agent = cx.update(|cx| {
4390 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
4391 });
4392 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4393
4394 let acp_thread = cx
4395 .update(|cx| {
4396 connection
4397 .clone()
4398 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
4399 })
4400 .await
4401 .unwrap();
4402 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4403 let thread = agent.read_with(cx, |agent, _| {
4404 agent.sessions.get(&session_id).unwrap().thread.clone()
4405 });
4406 let model = Arc::new(FakeLanguageModel::default());
4407
4408 // Ensure empty threads are not saved, even if they get mutated.
4409 thread.update(cx, |thread, cx| {
4410 thread.set_model(model.clone(), cx);
4411 });
4412 cx.run_until_parked();
4413
4414 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4415 cx.run_until_parked();
4416 model.send_last_completion_stream_text_chunk("spawning subagent");
4417 let subagent_tool_input = SpawnAgentToolInput {
4418 label: "label".to_string(),
4419 message: "subagent task prompt".to_string(),
4420 session_id: None,
4421 };
4422 let subagent_tool_use = LanguageModelToolUse {
4423 id: "subagent_1".into(),
4424 name: SpawnAgentTool::NAME.into(),
4425 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4426 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4427 is_input_complete: true,
4428 thought_signature: None,
4429 };
4430 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4431 subagent_tool_use,
4432 ));
4433 model.end_last_completion_stream();
4434
4435 cx.run_until_parked();
4436
4437 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4438 thread
4439 .running_subagent_ids(cx)
4440 .get(0)
4441 .expect("subagent thread should be running")
4442 .clone()
4443 });
4444
4445 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4446 agent
4447 .sessions
4448 .get(&subagent_session_id)
4449 .expect("subagent session should exist")
4450 .acp_thread
4451 .clone()
4452 });
4453
4454 model.send_last_completion_stream_text_chunk("subagent task response");
4455 model.end_last_completion_stream();
4456
4457 cx.run_until_parked();
4458
4459 assert_eq!(
4460 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4461 indoc! {"
4462 ## User
4463
4464 subagent task prompt
4465
4466 ## Assistant
4467
4468 subagent task response
4469
4470 "}
4471 );
4472
4473 model.send_last_completion_stream_text_chunk("Response");
4474 model.end_last_completion_stream();
4475
4476 send.await.unwrap();
4477
4478 assert_eq!(
4479 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4480 indoc! {r#"
4481 ## User
4482
4483 Prompt
4484
4485 ## Assistant
4486
4487 spawning subagent
4488
4489 **Tool Call: label**
4490 Status: Completed
4491
4492 subagent task response
4493
4494 ## Assistant
4495
4496 Response
4497
4498 "#},
4499 );
4500}
4501
4502#[gpui::test]
4503async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
4504 init_test(cx);
4505 cx.update(|cx| {
4506 LanguageModelRegistry::test(cx);
4507 });
4508 cx.update(|cx| {
4509 cx.update_flags(true, vec!["subagents".to_string()]);
4510 });
4511
4512 let fs = FakeFs::new(cx.executor());
4513 fs.insert_tree(
4514 "/",
4515 json!({
4516 "a": {
4517 "b.md": "Lorem"
4518 }
4519 }),
4520 )
4521 .await;
4522 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4523 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4524 let agent = cx.update(|cx| {
4525 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
4526 });
4527 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4528
4529 let acp_thread = cx
4530 .update(|cx| {
4531 connection
4532 .clone()
4533 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
4534 })
4535 .await
4536 .unwrap();
4537 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4538 let thread = agent.read_with(cx, |agent, _| {
4539 agent.sessions.get(&session_id).unwrap().thread.clone()
4540 });
4541 let model = Arc::new(FakeLanguageModel::default());
4542
4543 // Ensure empty threads are not saved, even if they get mutated.
4544 thread.update(cx, |thread, cx| {
4545 thread.set_model(model.clone(), cx);
4546 });
4547 cx.run_until_parked();
4548
4549 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4550 cx.run_until_parked();
4551 model.send_last_completion_stream_text_chunk("spawning subagent");
4552 let subagent_tool_input = SpawnAgentToolInput {
4553 label: "label".to_string(),
4554 message: "subagent task prompt".to_string(),
4555 session_id: None,
4556 };
4557 let subagent_tool_use = LanguageModelToolUse {
4558 id: "subagent_1".into(),
4559 name: SpawnAgentTool::NAME.into(),
4560 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4561 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4562 is_input_complete: true,
4563 thought_signature: None,
4564 };
4565 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4566 subagent_tool_use,
4567 ));
4568 model.end_last_completion_stream();
4569
4570 cx.run_until_parked();
4571
4572 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4573 thread
4574 .running_subagent_ids(cx)
4575 .get(0)
4576 .expect("subagent thread should be running")
4577 .clone()
4578 });
4579
4580 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4581 agent
4582 .sessions
4583 .get(&subagent_session_id)
4584 .expect("subagent session should exist")
4585 .acp_thread
4586 .clone()
4587 });
4588
4589 model.send_last_completion_stream_text_chunk("subagent task response 1");
4590 model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
4591 text: "thinking more about the subagent task".into(),
4592 signature: None,
4593 });
4594 model.send_last_completion_stream_text_chunk("subagent task response 2");
4595 model.end_last_completion_stream();
4596
4597 cx.run_until_parked();
4598
4599 assert_eq!(
4600 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4601 indoc! {"
4602 ## User
4603
4604 subagent task prompt
4605
4606 ## Assistant
4607
4608 subagent task response 1
4609
4610 <thinking>
4611 thinking more about the subagent task
4612 </thinking>
4613
4614 subagent task response 2
4615
4616 "}
4617 );
4618
4619 model.send_last_completion_stream_text_chunk("Response");
4620 model.end_last_completion_stream();
4621
4622 send.await.unwrap();
4623
4624 assert_eq!(
4625 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4626 indoc! {r#"
4627 ## User
4628
4629 Prompt
4630
4631 ## Assistant
4632
4633 spawning subagent
4634
4635 **Tool Call: label**
4636 Status: Completed
4637
4638 subagent task response 1
4639
4640 subagent task response 2
4641
4642 ## Assistant
4643
4644 Response
4645
4646 "#},
4647 );
4648}
4649
4650#[gpui::test]
4651async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4652 init_test(cx);
4653 cx.update(|cx| {
4654 LanguageModelRegistry::test(cx);
4655 });
4656 cx.update(|cx| {
4657 cx.update_flags(true, vec!["subagents".to_string()]);
4658 });
4659
4660 let fs = FakeFs::new(cx.executor());
4661 fs.insert_tree(
4662 "/",
4663 json!({
4664 "a": {
4665 "b.md": "Lorem"
4666 }
4667 }),
4668 )
4669 .await;
4670 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4671 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4672 let agent = cx.update(|cx| {
4673 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
4674 });
4675 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4676
4677 let acp_thread = cx
4678 .update(|cx| {
4679 connection
4680 .clone()
4681 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
4682 })
4683 .await
4684 .unwrap();
4685 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4686 let thread = agent.read_with(cx, |agent, _| {
4687 agent.sessions.get(&session_id).unwrap().thread.clone()
4688 });
4689 let model = Arc::new(FakeLanguageModel::default());
4690
4691 // Ensure empty threads are not saved, even if they get mutated.
4692 thread.update(cx, |thread, cx| {
4693 thread.set_model(model.clone(), cx);
4694 });
4695 cx.run_until_parked();
4696
4697 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4698 cx.run_until_parked();
4699 model.send_last_completion_stream_text_chunk("spawning subagent");
4700 let subagent_tool_input = SpawnAgentToolInput {
4701 label: "label".to_string(),
4702 message: "subagent task prompt".to_string(),
4703 session_id: None,
4704 };
4705 let subagent_tool_use = LanguageModelToolUse {
4706 id: "subagent_1".into(),
4707 name: SpawnAgentTool::NAME.into(),
4708 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4709 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4710 is_input_complete: true,
4711 thought_signature: None,
4712 };
4713 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4714 subagent_tool_use,
4715 ));
4716 model.end_last_completion_stream();
4717
4718 cx.run_until_parked();
4719
4720 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4721 thread
4722 .running_subagent_ids(cx)
4723 .get(0)
4724 .expect("subagent thread should be running")
4725 .clone()
4726 });
4727 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4728 agent
4729 .sessions
4730 .get(&subagent_session_id)
4731 .expect("subagent session should exist")
4732 .acp_thread
4733 .clone()
4734 });
4735
4736 // model.send_last_completion_stream_text_chunk("subagent task response");
4737 // model.end_last_completion_stream();
4738
4739 // cx.run_until_parked();
4740
4741 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4742
4743 cx.run_until_parked();
4744
4745 send.await.unwrap();
4746
4747 acp_thread.read_with(cx, |thread, cx| {
4748 assert_eq!(thread.status(), ThreadStatus::Idle);
4749 assert_eq!(
4750 thread.to_markdown(cx),
4751 indoc! {"
4752 ## User
4753
4754 Prompt
4755
4756 ## Assistant
4757
4758 spawning subagent
4759
4760 **Tool Call: label**
4761 Status: Canceled
4762
4763 "}
4764 );
4765 });
4766 subagent_acp_thread.read_with(cx, |thread, cx| {
4767 assert_eq!(thread.status(), ThreadStatus::Idle);
4768 assert_eq!(
4769 thread.to_markdown(cx),
4770 indoc! {"
4771 ## User
4772
4773 subagent task prompt
4774
4775 "}
4776 );
4777 });
4778}
4779
4780#[gpui::test]
4781async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4782 init_test(cx);
4783 cx.update(|cx| {
4784 LanguageModelRegistry::test(cx);
4785 });
4786 cx.update(|cx| {
4787 cx.update_flags(true, vec!["subagents".to_string()]);
4788 });
4789
4790 let fs = FakeFs::new(cx.executor());
4791 fs.insert_tree(
4792 "/",
4793 json!({
4794 "a": {
4795 "b.md": "Lorem"
4796 }
4797 }),
4798 )
4799 .await;
4800 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4801 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4802 let agent = cx.update(|cx| {
4803 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
4804 });
4805 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4806
4807 let acp_thread = cx
4808 .update(|cx| {
4809 connection
4810 .clone()
4811 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
4812 })
4813 .await
4814 .unwrap();
4815 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4816 let thread = agent.read_with(cx, |agent, _| {
4817 agent.sessions.get(&session_id).unwrap().thread.clone()
4818 });
4819 let model = Arc::new(FakeLanguageModel::default());
4820
4821 thread.update(cx, |thread, cx| {
4822 thread.set_model(model.clone(), cx);
4823 });
4824 cx.run_until_parked();
4825
4826 // === First turn: create subagent ===
4827 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4828 cx.run_until_parked();
4829 model.send_last_completion_stream_text_chunk("spawning subagent");
4830 let subagent_tool_input = SpawnAgentToolInput {
4831 label: "initial task".to_string(),
4832 message: "do the first task".to_string(),
4833 session_id: None,
4834 };
4835 let subagent_tool_use = LanguageModelToolUse {
4836 id: "subagent_1".into(),
4837 name: SpawnAgentTool::NAME.into(),
4838 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4839 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4840 is_input_complete: true,
4841 thought_signature: None,
4842 };
4843 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4844 subagent_tool_use,
4845 ));
4846 model.end_last_completion_stream();
4847
4848 cx.run_until_parked();
4849
4850 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4851 thread
4852 .running_subagent_ids(cx)
4853 .get(0)
4854 .expect("subagent thread should be running")
4855 .clone()
4856 });
4857
4858 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4859 agent
4860 .sessions
4861 .get(&subagent_session_id)
4862 .expect("subagent session should exist")
4863 .acp_thread
4864 .clone()
4865 });
4866
4867 // Subagent responds
4868 model.send_last_completion_stream_text_chunk("first task response");
4869 model.end_last_completion_stream();
4870
4871 cx.run_until_parked();
4872
4873 // Parent model responds to complete first turn
4874 model.send_last_completion_stream_text_chunk("First response");
4875 model.end_last_completion_stream();
4876
4877 send.await.unwrap();
4878
4879 // Verify subagent is no longer running
4880 thread.read_with(cx, |thread, cx| {
4881 assert!(
4882 thread.running_subagent_ids(cx).is_empty(),
4883 "subagent should not be running after completion"
4884 );
4885 });
4886
4887 // === Second turn: resume subagent with session_id ===
4888 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4889 cx.run_until_parked();
4890 model.send_last_completion_stream_text_chunk("resuming subagent");
4891 let resume_tool_input = SpawnAgentToolInput {
4892 label: "follow-up task".to_string(),
4893 message: "do the follow-up task".to_string(),
4894 session_id: Some(subagent_session_id.clone()),
4895 };
4896 let resume_tool_use = LanguageModelToolUse {
4897 id: "subagent_2".into(),
4898 name: SpawnAgentTool::NAME.into(),
4899 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4900 input: serde_json::to_value(&resume_tool_input).unwrap(),
4901 is_input_complete: true,
4902 thought_signature: None,
4903 };
4904 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4905 model.end_last_completion_stream();
4906
4907 cx.run_until_parked();
4908
4909 // Subagent should be running again with the same session
4910 thread.read_with(cx, |thread, cx| {
4911 let running = thread.running_subagent_ids(cx);
4912 assert_eq!(running.len(), 1, "subagent should be running");
4913 assert_eq!(running[0], subagent_session_id, "should be same session");
4914 });
4915
4916 // Subagent responds to follow-up
4917 model.send_last_completion_stream_text_chunk("follow-up task response");
4918 model.end_last_completion_stream();
4919
4920 cx.run_until_parked();
4921
4922 // Parent model responds to complete second turn
4923 model.send_last_completion_stream_text_chunk("Second response");
4924 model.end_last_completion_stream();
4925
4926 send2.await.unwrap();
4927
4928 // Verify subagent is no longer running
4929 thread.read_with(cx, |thread, cx| {
4930 assert!(
4931 thread.running_subagent_ids(cx).is_empty(),
4932 "subagent should not be running after resume completion"
4933 );
4934 });
4935
4936 // Verify the subagent's acp thread has both conversation turns
4937 assert_eq!(
4938 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4939 indoc! {"
4940 ## User
4941
4942 do the first task
4943
4944 ## Assistant
4945
4946 first task response
4947
4948 ## User
4949
4950 do the follow-up task
4951
4952 ## Assistant
4953
4954 follow-up task response
4955
4956 "}
4957 );
4958}
4959
4960#[gpui::test]
4961async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4962 init_test(cx);
4963
4964 cx.update(|cx| {
4965 cx.update_flags(true, vec!["subagents".to_string()]);
4966 });
4967
4968 let fs = FakeFs::new(cx.executor());
4969 fs.insert_tree(path!("/test"), json!({})).await;
4970 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4971 let project_context = cx.new(|_cx| ProjectContext::default());
4972 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4973 let context_server_registry =
4974 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4975 let model = Arc::new(FakeLanguageModel::default());
4976
4977 let environment = Rc::new(cx.update(|cx| {
4978 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4979 }));
4980
4981 let thread = cx.new(|cx| {
4982 let mut thread = Thread::new(
4983 project.clone(),
4984 project_context,
4985 context_server_registry,
4986 Templates::new(),
4987 Some(model),
4988 cx,
4989 );
4990 thread.add_default_tools(environment, cx);
4991 thread
4992 });
4993
4994 thread.read_with(cx, |thread, _| {
4995 assert!(
4996 thread.has_registered_tool(SpawnAgentTool::NAME),
4997 "subagent tool should be present when feature flag is enabled"
4998 );
4999 });
5000}
5001
5002#[gpui::test]
5003async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
5004 init_test(cx);
5005
5006 cx.update(|cx| {
5007 cx.update_flags(true, vec!["subagents".to_string()]);
5008 });
5009
5010 let fs = FakeFs::new(cx.executor());
5011 fs.insert_tree(path!("/test"), json!({})).await;
5012 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5013 let project_context = cx.new(|_cx| ProjectContext::default());
5014 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5015 let context_server_registry =
5016 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5017 let model = Arc::new(FakeLanguageModel::default());
5018
5019 let parent_thread = cx.new(|cx| {
5020 Thread::new(
5021 project.clone(),
5022 project_context,
5023 context_server_registry,
5024 Templates::new(),
5025 Some(model.clone()),
5026 cx,
5027 )
5028 });
5029
5030 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
5031 subagent_thread.read_with(cx, |subagent_thread, cx| {
5032 assert!(subagent_thread.is_subagent());
5033 assert_eq!(subagent_thread.depth(), 1);
5034 assert_eq!(
5035 subagent_thread.model().map(|model| model.id()),
5036 Some(model.id())
5037 );
5038 assert_eq!(
5039 subagent_thread.parent_thread_id(),
5040 Some(parent_thread.read(cx).id().clone())
5041 );
5042 });
5043}
5044
5045#[gpui::test]
5046async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
5047 init_test(cx);
5048
5049 cx.update(|cx| {
5050 cx.update_flags(true, vec!["subagents".to_string()]);
5051 });
5052
5053 let fs = FakeFs::new(cx.executor());
5054 fs.insert_tree(path!("/test"), json!({})).await;
5055 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5056 let project_context = cx.new(|_cx| ProjectContext::default());
5057 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5058 let context_server_registry =
5059 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5060 let model = Arc::new(FakeLanguageModel::default());
5061 let environment = Rc::new(cx.update(|cx| {
5062 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
5063 }));
5064
5065 let deep_parent_thread = cx.new(|cx| {
5066 let mut thread = Thread::new(
5067 project.clone(),
5068 project_context,
5069 context_server_registry,
5070 Templates::new(),
5071 Some(model.clone()),
5072 cx,
5073 );
5074 thread.set_subagent_context(SubagentContext {
5075 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
5076 depth: MAX_SUBAGENT_DEPTH - 1,
5077 });
5078 thread
5079 });
5080 let deep_subagent_thread = cx.new(|cx| {
5081 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
5082 thread.add_default_tools(environment, cx);
5083 thread
5084 });
5085
5086 deep_subagent_thread.read_with(cx, |thread, _| {
5087 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
5088 assert!(
5089 !thread.has_registered_tool(SpawnAgentTool::NAME),
5090 "subagent tool should not be present at max depth"
5091 );
5092 });
5093}
5094
5095#[gpui::test]
5096async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
5097 init_test(cx);
5098
5099 cx.update(|cx| {
5100 cx.update_flags(true, vec!["subagents".to_string()]);
5101 });
5102
5103 let fs = FakeFs::new(cx.executor());
5104 fs.insert_tree(path!("/test"), json!({})).await;
5105 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5106 let project_context = cx.new(|_cx| ProjectContext::default());
5107 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5108 let context_server_registry =
5109 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5110 let model = Arc::new(FakeLanguageModel::default());
5111
5112 let parent = cx.new(|cx| {
5113 Thread::new(
5114 project.clone(),
5115 project_context.clone(),
5116 context_server_registry.clone(),
5117 Templates::new(),
5118 Some(model.clone()),
5119 cx,
5120 )
5121 });
5122
5123 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
5124
5125 parent.update(cx, |thread, _cx| {
5126 thread.register_running_subagent(subagent.downgrade());
5127 });
5128
5129 subagent
5130 .update(cx, |thread, cx| {
5131 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
5132 })
5133 .unwrap();
5134 cx.run_until_parked();
5135
5136 subagent.read_with(cx, |thread, _| {
5137 assert!(!thread.is_turn_complete(), "subagent should be running");
5138 });
5139
5140 parent.update(cx, |thread, cx| {
5141 thread.cancel(cx).detach();
5142 });
5143
5144 subagent.read_with(cx, |thread, _| {
5145 assert!(
5146 thread.is_turn_complete(),
5147 "subagent should be cancelled when parent cancels"
5148 );
5149 });
5150}
5151
5152#[gpui::test]
5153async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
5154 init_test(cx);
5155 cx.update(|cx| {
5156 LanguageModelRegistry::test(cx);
5157 });
5158 cx.update(|cx| {
5159 cx.update_flags(true, vec!["subagents".to_string()]);
5160 });
5161
5162 let fs = FakeFs::new(cx.executor());
5163 fs.insert_tree(
5164 "/",
5165 json!({
5166 "a": {
5167 "b.md": "Lorem"
5168 }
5169 }),
5170 )
5171 .await;
5172 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5173 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5174 let agent = cx.update(|cx| {
5175 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
5176 });
5177 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5178
5179 let acp_thread = cx
5180 .update(|cx| {
5181 connection
5182 .clone()
5183 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
5184 })
5185 .await
5186 .unwrap();
5187 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5188 let thread = agent.read_with(cx, |agent, _| {
5189 agent.sessions.get(&session_id).unwrap().thread.clone()
5190 });
5191 let model = Arc::new(FakeLanguageModel::default());
5192
5193 thread.update(cx, |thread, cx| {
5194 thread.set_model(model.clone(), cx);
5195 });
5196 cx.run_until_parked();
5197
5198 // Start the parent turn
5199 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5200 cx.run_until_parked();
5201 model.send_last_completion_stream_text_chunk("spawning subagent");
5202 let subagent_tool_input = SpawnAgentToolInput {
5203 label: "label".to_string(),
5204 message: "subagent task prompt".to_string(),
5205 session_id: None,
5206 };
5207 let subagent_tool_use = LanguageModelToolUse {
5208 id: "subagent_1".into(),
5209 name: SpawnAgentTool::NAME.into(),
5210 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5211 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5212 is_input_complete: true,
5213 thought_signature: None,
5214 };
5215 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5216 subagent_tool_use,
5217 ));
5218 model.end_last_completion_stream();
5219
5220 cx.run_until_parked();
5221
5222 // Verify subagent is running
5223 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5224 thread
5225 .running_subagent_ids(cx)
5226 .get(0)
5227 .expect("subagent thread should be running")
5228 .clone()
5229 });
5230
5231 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
5232 model.send_last_completion_stream_text_chunk("partial work");
5233 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5234 TokenUsage {
5235 input_tokens: 850_000,
5236 output_tokens: 0,
5237 cache_creation_input_tokens: 0,
5238 cache_read_input_tokens: 0,
5239 },
5240 ));
5241
5242 cx.run_until_parked();
5243
5244 // The subagent should no longer be running
5245 thread.read_with(cx, |thread, cx| {
5246 assert!(
5247 thread.running_subagent_ids(cx).is_empty(),
5248 "subagent should be stopped after context window warning"
5249 );
5250 });
5251
5252 // The parent model should get a new completion request to respond to the tool error
5253 model.send_last_completion_stream_text_chunk("Response after warning");
5254 model.end_last_completion_stream();
5255
5256 send.await.unwrap();
5257
5258 // Verify the parent thread shows the warning error in the tool call
5259 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5260 assert!(
5261 markdown.contains("nearing the end of its context window"),
5262 "tool output should contain context window warning message, got:\n{markdown}"
5263 );
5264 assert!(
5265 markdown.contains("Status: Failed"),
5266 "tool call should have Failed status, got:\n{markdown}"
5267 );
5268
5269 // Verify the subagent session still exists (can be resumed)
5270 agent.read_with(cx, |agent, _cx| {
5271 assert!(
5272 agent.sessions.contains_key(&subagent_session_id),
5273 "subagent session should still exist for potential resume"
5274 );
5275 });
5276}
5277
5278#[gpui::test]
5279async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
5280 init_test(cx);
5281 cx.update(|cx| {
5282 LanguageModelRegistry::test(cx);
5283 });
5284 cx.update(|cx| {
5285 cx.update_flags(true, vec!["subagents".to_string()]);
5286 });
5287
5288 let fs = FakeFs::new(cx.executor());
5289 fs.insert_tree(
5290 "/",
5291 json!({
5292 "a": {
5293 "b.md": "Lorem"
5294 }
5295 }),
5296 )
5297 .await;
5298 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5299 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5300 let agent = cx.update(|cx| {
5301 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
5302 });
5303 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5304
5305 let acp_thread = cx
5306 .update(|cx| {
5307 connection
5308 .clone()
5309 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
5310 })
5311 .await
5312 .unwrap();
5313 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5314 let thread = agent.read_with(cx, |agent, _| {
5315 agent.sessions.get(&session_id).unwrap().thread.clone()
5316 });
5317 let model = Arc::new(FakeLanguageModel::default());
5318
5319 thread.update(cx, |thread, cx| {
5320 thread.set_model(model.clone(), cx);
5321 });
5322 cx.run_until_parked();
5323
5324 // === First turn: create subagent, trigger context window warning ===
5325 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5326 cx.run_until_parked();
5327 model.send_last_completion_stream_text_chunk("spawning subagent");
5328 let subagent_tool_input = SpawnAgentToolInput {
5329 label: "initial task".to_string(),
5330 message: "do the first task".to_string(),
5331 session_id: None,
5332 };
5333 let subagent_tool_use = LanguageModelToolUse {
5334 id: "subagent_1".into(),
5335 name: SpawnAgentTool::NAME.into(),
5336 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5337 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5338 is_input_complete: true,
5339 thought_signature: None,
5340 };
5341 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5342 subagent_tool_use,
5343 ));
5344 model.end_last_completion_stream();
5345
5346 cx.run_until_parked();
5347
5348 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5349 thread
5350 .running_subagent_ids(cx)
5351 .get(0)
5352 .expect("subagent thread should be running")
5353 .clone()
5354 });
5355
5356 // Subagent sends a usage update that crosses the warning threshold.
5357 // This triggers Normal→Warning, stopping the subagent.
5358 model.send_last_completion_stream_text_chunk("partial work");
5359 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5360 TokenUsage {
5361 input_tokens: 850_000,
5362 output_tokens: 0,
5363 cache_creation_input_tokens: 0,
5364 cache_read_input_tokens: 0,
5365 },
5366 ));
5367
5368 cx.run_until_parked();
5369
5370 // Verify the first turn was stopped with a context window warning
5371 thread.read_with(cx, |thread, cx| {
5372 assert!(
5373 thread.running_subagent_ids(cx).is_empty(),
5374 "subagent should be stopped after context window warning"
5375 );
5376 });
5377
5378 // Parent model responds to complete first turn
5379 model.send_last_completion_stream_text_chunk("First response");
5380 model.end_last_completion_stream();
5381
5382 send.await.unwrap();
5383
5384 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5385 assert!(
5386 markdown.contains("nearing the end of its context window"),
5387 "first turn should have context window warning, got:\n{markdown}"
5388 );
5389
5390 // === Second turn: resume the same subagent (now at Warning level) ===
5391 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5392 cx.run_until_parked();
5393 model.send_last_completion_stream_text_chunk("resuming subagent");
5394 let resume_tool_input = SpawnAgentToolInput {
5395 label: "follow-up task".to_string(),
5396 message: "do the follow-up task".to_string(),
5397 session_id: Some(subagent_session_id.clone()),
5398 };
5399 let resume_tool_use = LanguageModelToolUse {
5400 id: "subagent_2".into(),
5401 name: SpawnAgentTool::NAME.into(),
5402 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5403 input: serde_json::to_value(&resume_tool_input).unwrap(),
5404 is_input_complete: true,
5405 thought_signature: None,
5406 };
5407 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5408 model.end_last_completion_stream();
5409
5410 cx.run_until_parked();
5411
5412 // Subagent responds with tokens still at warning level (no worse).
5413 // Since ratio_before_prompt was already Warning, this should NOT
5414 // trigger the context window warning again.
5415 model.send_last_completion_stream_text_chunk("follow-up task response");
5416 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5417 TokenUsage {
5418 input_tokens: 870_000,
5419 output_tokens: 0,
5420 cache_creation_input_tokens: 0,
5421 cache_read_input_tokens: 0,
5422 },
5423 ));
5424 model.end_last_completion_stream();
5425
5426 cx.run_until_parked();
5427
5428 // Parent model responds to complete second turn
5429 model.send_last_completion_stream_text_chunk("Second response");
5430 model.end_last_completion_stream();
5431
5432 send2.await.unwrap();
5433
5434 // The resumed subagent should have completed normally since the ratio
5435 // didn't transition (it was Warning before and stayed at Warning)
5436 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5437 assert!(
5438 markdown.contains("follow-up task response"),
5439 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5440 );
5441 // The second tool call should NOT have a context window warning
5442 let second_tool_pos = markdown
5443 .find("follow-up task")
5444 .expect("should find follow-up tool call");
5445 let after_second_tool = &markdown[second_tool_pos..];
5446 assert!(
5447 !after_second_tool.contains("nearing the end of its context window"),
5448 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5449 );
5450}
5451
// When a subagent's model stream fails with a non-retryable error, the
// subagent must stop running, the failure must surface on the parent thread's
// tool call (Status: Failed), and the parent turn must still complete.
#[gpui::test]
async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
    init_test(cx);
    // Install a test language-model registry and enable the subagents feature flag.
    cx.update(|cx| {
        LanguageModelRegistry::test(cx);
    });
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = cx.update(|cx| {
        NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
    });
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Open a session and grab the underlying Thread entity for the session id.
    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    // The same fake model serves both the parent and the subagent; each
    // `*_last_completion_stream_*` call below targets the most recent request.
    let model = Arc::new(FakeLanguageModel::default());

    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });
    cx.run_until_parked();

    // Start the parent turn
    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("spawning subagent");
    // Parent model issues a SpawnAgentTool call, which starts the subagent.
    let subagent_tool_input = SpawnAgentToolInput {
        label: "label".to_string(),
        message: "subagent task prompt".to_string(),
        session_id: None,
    };
    let subagent_tool_use = LanguageModelToolUse {
        id: "subagent_1".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
        input: serde_json::to_value(&subagent_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        subagent_tool_use,
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Verify subagent is running
    thread.read_with(cx, |thread, cx| {
        assert!(
            !thread.running_subagent_ids(cx).is_empty(),
            "subagent should be running"
        );
    });

    // The subagent's model returns a non-retryable error
    model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
        tokens: None,
    });

    cx.run_until_parked();

    // The subagent should no longer be running
    thread.read_with(cx, |thread, cx| {
        assert!(
            thread.running_subagent_ids(cx).is_empty(),
            "subagent should not be running after error"
        );
    });

    // The parent model should get a new completion request to respond to the tool error
    model.send_last_completion_stream_text_chunk("Response after error");
    model.end_last_completion_stream();

    send.await.unwrap();

    // Verify the parent thread shows the error in the tool call
    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("Status: Failed"),
        "tool call should have Failed status after model error, got:\n{markdown}"
    );
}
5558
5559#[gpui::test]
5560async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5561 init_test(cx);
5562
5563 let fs = FakeFs::new(cx.executor());
5564 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5565 .await;
5566 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5567
5568 cx.update(|cx| {
5569 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5570 settings.tool_permissions.tools.insert(
5571 EditFileTool::NAME.into(),
5572 agent_settings::ToolRules {
5573 default: Some(settings::ToolPermissionMode::Allow),
5574 always_allow: vec![],
5575 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5576 always_confirm: vec![],
5577 invalid_patterns: vec![],
5578 },
5579 );
5580 agent_settings::AgentSettings::override_global(settings, cx);
5581 });
5582
5583 let context_server_registry =
5584 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5585 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5586 let templates = crate::Templates::new();
5587 let thread = cx.new(|cx| {
5588 crate::Thread::new(
5589 project.clone(),
5590 cx.new(|_cx| prompt_store::ProjectContext::default()),
5591 context_server_registry,
5592 templates.clone(),
5593 None,
5594 cx,
5595 )
5596 });
5597
5598 #[allow(clippy::arc_with_non_send_sync)]
5599 let tool = Arc::new(crate::EditFileTool::new(
5600 project.clone(),
5601 thread.downgrade(),
5602 language_registry,
5603 templates,
5604 ));
5605 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5606
5607 let task = cx.update(|cx| {
5608 tool.run(
5609 ToolInput::resolved(crate::EditFileToolInput {
5610 display_description: "Edit sensitive file".to_string(),
5611 path: "root/sensitive_config.txt".into(),
5612 mode: crate::EditFileMode::Edit,
5613 }),
5614 event_stream,
5615 cx,
5616 )
5617 });
5618
5619 let result = task.await;
5620 assert!(result.is_err(), "expected edit to be blocked");
5621 assert!(
5622 result.unwrap_err().to_string().contains("blocked"),
5623 "error should mention the edit was blocked"
5624 );
5625}
5626
5627#[gpui::test]
5628async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5629 init_test(cx);
5630
5631 let fs = FakeFs::new(cx.executor());
5632 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5633 .await;
5634 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5635
5636 cx.update(|cx| {
5637 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5638 settings.tool_permissions.tools.insert(
5639 DeletePathTool::NAME.into(),
5640 agent_settings::ToolRules {
5641 default: Some(settings::ToolPermissionMode::Allow),
5642 always_allow: vec![],
5643 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5644 always_confirm: vec![],
5645 invalid_patterns: vec![],
5646 },
5647 );
5648 agent_settings::AgentSettings::override_global(settings, cx);
5649 });
5650
5651 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5652
5653 #[allow(clippy::arc_with_non_send_sync)]
5654 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5655 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5656
5657 let task = cx.update(|cx| {
5658 tool.run(
5659 ToolInput::resolved(crate::DeletePathToolInput {
5660 path: "root/important_data.txt".to_string(),
5661 }),
5662 event_stream,
5663 cx,
5664 )
5665 });
5666
5667 let result = task.await;
5668 assert!(result.is_err(), "expected deletion to be blocked");
5669 assert!(
5670 result.unwrap_err().contains("blocked"),
5671 "error should mention the deletion was blocked"
5672 );
5673}
5674
5675#[gpui::test]
5676async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5677 init_test(cx);
5678
5679 let fs = FakeFs::new(cx.executor());
5680 fs.insert_tree(
5681 "/root",
5682 json!({
5683 "safe.txt": "content",
5684 "protected": {}
5685 }),
5686 )
5687 .await;
5688 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5689
5690 cx.update(|cx| {
5691 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5692 settings.tool_permissions.tools.insert(
5693 MovePathTool::NAME.into(),
5694 agent_settings::ToolRules {
5695 default: Some(settings::ToolPermissionMode::Allow),
5696 always_allow: vec![],
5697 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5698 always_confirm: vec![],
5699 invalid_patterns: vec![],
5700 },
5701 );
5702 agent_settings::AgentSettings::override_global(settings, cx);
5703 });
5704
5705 #[allow(clippy::arc_with_non_send_sync)]
5706 let tool = Arc::new(crate::MovePathTool::new(project));
5707 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5708
5709 let task = cx.update(|cx| {
5710 tool.run(
5711 ToolInput::resolved(crate::MovePathToolInput {
5712 source_path: "root/safe.txt".to_string(),
5713 destination_path: "root/protected/safe.txt".to_string(),
5714 }),
5715 event_stream,
5716 cx,
5717 )
5718 });
5719
5720 let result = task.await;
5721 assert!(
5722 result.is_err(),
5723 "expected move to be blocked due to destination path"
5724 );
5725 assert!(
5726 result.unwrap_err().contains("blocked"),
5727 "error should mention the move was blocked"
5728 );
5729}
5730
5731#[gpui::test]
5732async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5733 init_test(cx);
5734
5735 let fs = FakeFs::new(cx.executor());
5736 fs.insert_tree(
5737 "/root",
5738 json!({
5739 "secret.txt": "secret content",
5740 "public": {}
5741 }),
5742 )
5743 .await;
5744 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5745
5746 cx.update(|cx| {
5747 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5748 settings.tool_permissions.tools.insert(
5749 MovePathTool::NAME.into(),
5750 agent_settings::ToolRules {
5751 default: Some(settings::ToolPermissionMode::Allow),
5752 always_allow: vec![],
5753 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5754 always_confirm: vec![],
5755 invalid_patterns: vec![],
5756 },
5757 );
5758 agent_settings::AgentSettings::override_global(settings, cx);
5759 });
5760
5761 #[allow(clippy::arc_with_non_send_sync)]
5762 let tool = Arc::new(crate::MovePathTool::new(project));
5763 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5764
5765 let task = cx.update(|cx| {
5766 tool.run(
5767 ToolInput::resolved(crate::MovePathToolInput {
5768 source_path: "root/secret.txt".to_string(),
5769 destination_path: "root/public/not_secret.txt".to_string(),
5770 }),
5771 event_stream,
5772 cx,
5773 )
5774 });
5775
5776 let result = task.await;
5777 assert!(
5778 result.is_err(),
5779 "expected move to be blocked due to source path"
5780 );
5781 assert!(
5782 result.unwrap_err().contains("blocked"),
5783 "error should mention the move was blocked"
5784 );
5785}
5786
5787#[gpui::test]
5788async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5789 init_test(cx);
5790
5791 let fs = FakeFs::new(cx.executor());
5792 fs.insert_tree(
5793 "/root",
5794 json!({
5795 "confidential.txt": "confidential data",
5796 "dest": {}
5797 }),
5798 )
5799 .await;
5800 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5801
5802 cx.update(|cx| {
5803 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5804 settings.tool_permissions.tools.insert(
5805 CopyPathTool::NAME.into(),
5806 agent_settings::ToolRules {
5807 default: Some(settings::ToolPermissionMode::Allow),
5808 always_allow: vec![],
5809 always_deny: vec![
5810 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5811 ],
5812 always_confirm: vec![],
5813 invalid_patterns: vec![],
5814 },
5815 );
5816 agent_settings::AgentSettings::override_global(settings, cx);
5817 });
5818
5819 #[allow(clippy::arc_with_non_send_sync)]
5820 let tool = Arc::new(crate::CopyPathTool::new(project));
5821 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5822
5823 let task = cx.update(|cx| {
5824 tool.run(
5825 ToolInput::resolved(crate::CopyPathToolInput {
5826 source_path: "root/confidential.txt".to_string(),
5827 destination_path: "root/dest/copy.txt".to_string(),
5828 }),
5829 event_stream,
5830 cx,
5831 )
5832 });
5833
5834 let result = task.await;
5835 assert!(result.is_err(), "expected copy to be blocked");
5836 assert!(
5837 result.unwrap_err().contains("blocked"),
5838 "error should mention the copy was blocked"
5839 );
5840}
5841
5842#[gpui::test]
5843async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5844 init_test(cx);
5845
5846 let fs = FakeFs::new(cx.executor());
5847 fs.insert_tree(
5848 "/root",
5849 json!({
5850 "normal.txt": "normal content",
5851 "readonly": {
5852 "config.txt": "readonly content"
5853 }
5854 }),
5855 )
5856 .await;
5857 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5858
5859 cx.update(|cx| {
5860 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5861 settings.tool_permissions.tools.insert(
5862 SaveFileTool::NAME.into(),
5863 agent_settings::ToolRules {
5864 default: Some(settings::ToolPermissionMode::Allow),
5865 always_allow: vec![],
5866 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5867 always_confirm: vec![],
5868 invalid_patterns: vec![],
5869 },
5870 );
5871 agent_settings::AgentSettings::override_global(settings, cx);
5872 });
5873
5874 #[allow(clippy::arc_with_non_send_sync)]
5875 let tool = Arc::new(crate::SaveFileTool::new(project));
5876 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5877
5878 let task = cx.update(|cx| {
5879 tool.run(
5880 ToolInput::resolved(crate::SaveFileToolInput {
5881 paths: vec![
5882 std::path::PathBuf::from("root/normal.txt"),
5883 std::path::PathBuf::from("root/readonly/config.txt"),
5884 ],
5885 }),
5886 event_stream,
5887 cx,
5888 )
5889 });
5890
5891 let result = task.await;
5892 assert!(
5893 result.is_err(),
5894 "expected save to be blocked due to denied path"
5895 );
5896 assert!(
5897 result.unwrap_err().contains("blocked"),
5898 "error should mention the save was blocked"
5899 );
5900}
5901
5902#[gpui::test]
5903async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5904 init_test(cx);
5905
5906 let fs = FakeFs::new(cx.executor());
5907 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5908 .await;
5909 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5910
5911 cx.update(|cx| {
5912 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5913 settings.tool_permissions.tools.insert(
5914 SaveFileTool::NAME.into(),
5915 agent_settings::ToolRules {
5916 default: Some(settings::ToolPermissionMode::Allow),
5917 always_allow: vec![],
5918 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5919 always_confirm: vec![],
5920 invalid_patterns: vec![],
5921 },
5922 );
5923 agent_settings::AgentSettings::override_global(settings, cx);
5924 });
5925
5926 #[allow(clippy::arc_with_non_send_sync)]
5927 let tool = Arc::new(crate::SaveFileTool::new(project));
5928 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5929
5930 let task = cx.update(|cx| {
5931 tool.run(
5932 ToolInput::resolved(crate::SaveFileToolInput {
5933 paths: vec![std::path::PathBuf::from("root/config.secret")],
5934 }),
5935 event_stream,
5936 cx,
5937 )
5938 });
5939
5940 let result = task.await;
5941 assert!(result.is_err(), "expected save to be blocked");
5942 assert!(
5943 result.unwrap_err().contains("blocked"),
5944 "error should mention the save was blocked"
5945 );
5946}
5947
5948#[gpui::test]
5949async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5950 init_test(cx);
5951
5952 cx.update(|cx| {
5953 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5954 settings.tool_permissions.tools.insert(
5955 WebSearchTool::NAME.into(),
5956 agent_settings::ToolRules {
5957 default: Some(settings::ToolPermissionMode::Allow),
5958 always_allow: vec![],
5959 always_deny: vec![
5960 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5961 ],
5962 always_confirm: vec![],
5963 invalid_patterns: vec![],
5964 },
5965 );
5966 agent_settings::AgentSettings::override_global(settings, cx);
5967 });
5968
5969 #[allow(clippy::arc_with_non_send_sync)]
5970 let tool = Arc::new(crate::WebSearchTool);
5971 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5972
5973 let input: crate::WebSearchToolInput =
5974 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5975
5976 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5977
5978 let result = task.await;
5979 assert!(result.is_err(), "expected search to be blocked");
5980 match result.unwrap_err() {
5981 crate::WebSearchToolOutput::Error { error } => {
5982 assert!(
5983 error.contains("blocked"),
5984 "error should mention the search was blocked"
5985 );
5986 }
5987 other => panic!("expected Error variant, got: {other:?}"),
5988 }
5989}
5990
5991#[gpui::test]
5992async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5993 init_test(cx);
5994
5995 let fs = FakeFs::new(cx.executor());
5996 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5997 .await;
5998 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5999
6000 cx.update(|cx| {
6001 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6002 settings.tool_permissions.tools.insert(
6003 EditFileTool::NAME.into(),
6004 agent_settings::ToolRules {
6005 default: Some(settings::ToolPermissionMode::Confirm),
6006 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
6007 always_deny: vec![],
6008 always_confirm: vec![],
6009 invalid_patterns: vec![],
6010 },
6011 );
6012 agent_settings::AgentSettings::override_global(settings, cx);
6013 });
6014
6015 let context_server_registry =
6016 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6017 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6018 let templates = crate::Templates::new();
6019 let thread = cx.new(|cx| {
6020 crate::Thread::new(
6021 project.clone(),
6022 cx.new(|_cx| prompt_store::ProjectContext::default()),
6023 context_server_registry,
6024 templates.clone(),
6025 None,
6026 cx,
6027 )
6028 });
6029
6030 #[allow(clippy::arc_with_non_send_sync)]
6031 let tool = Arc::new(crate::EditFileTool::new(
6032 project,
6033 thread.downgrade(),
6034 language_registry,
6035 templates,
6036 ));
6037 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6038
6039 let _task = cx.update(|cx| {
6040 tool.run(
6041 ToolInput::resolved(crate::EditFileToolInput {
6042 display_description: "Edit README".to_string(),
6043 path: "root/README.md".into(),
6044 mode: crate::EditFileMode::Edit,
6045 }),
6046 event_stream,
6047 cx,
6048 )
6049 });
6050
6051 cx.run_until_parked();
6052
6053 let event = rx.try_next();
6054 assert!(
6055 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6056 "expected no authorization request for allowed .md file"
6057 );
6058}
6059
6060#[gpui::test]
6061async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
6062 init_test(cx);
6063
6064 let fs = FakeFs::new(cx.executor());
6065 fs.insert_tree(
6066 "/root",
6067 json!({
6068 ".zed": {
6069 "settings.json": "{}"
6070 },
6071 "README.md": "# Hello"
6072 }),
6073 )
6074 .await;
6075 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6076
6077 cx.update(|cx| {
6078 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6079 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
6080 agent_settings::AgentSettings::override_global(settings, cx);
6081 });
6082
6083 let context_server_registry =
6084 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6085 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6086 let templates = crate::Templates::new();
6087 let thread = cx.new(|cx| {
6088 crate::Thread::new(
6089 project.clone(),
6090 cx.new(|_cx| prompt_store::ProjectContext::default()),
6091 context_server_registry,
6092 templates.clone(),
6093 None,
6094 cx,
6095 )
6096 });
6097
6098 #[allow(clippy::arc_with_non_send_sync)]
6099 let tool = Arc::new(crate::EditFileTool::new(
6100 project,
6101 thread.downgrade(),
6102 language_registry,
6103 templates,
6104 ));
6105
6106 // Editing a file inside .zed/ should still prompt even with global default: allow,
6107 // because local settings paths are sensitive and require confirmation regardless.
6108 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6109 let _task = cx.update(|cx| {
6110 tool.run(
6111 ToolInput::resolved(crate::EditFileToolInput {
6112 display_description: "Edit local settings".to_string(),
6113 path: "root/.zed/settings.json".into(),
6114 mode: crate::EditFileMode::Edit,
6115 }),
6116 event_stream,
6117 cx,
6118 )
6119 });
6120
6121 let _update = rx.expect_update_fields().await;
6122 let _auth = rx.expect_authorization().await;
6123}
6124
6125#[gpui::test]
6126async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
6127 init_test(cx);
6128
6129 cx.update(|cx| {
6130 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6131 settings.tool_permissions.tools.insert(
6132 FetchTool::NAME.into(),
6133 agent_settings::ToolRules {
6134 default: Some(settings::ToolPermissionMode::Allow),
6135 always_allow: vec![],
6136 always_deny: vec![
6137 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
6138 ],
6139 always_confirm: vec![],
6140 invalid_patterns: vec![],
6141 },
6142 );
6143 agent_settings::AgentSettings::override_global(settings, cx);
6144 });
6145
6146 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6147
6148 #[allow(clippy::arc_with_non_send_sync)]
6149 let tool = Arc::new(crate::FetchTool::new(http_client));
6150 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6151
6152 let input: crate::FetchToolInput =
6153 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
6154
6155 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6156
6157 let result = task.await;
6158 assert!(result.is_err(), "expected fetch to be blocked");
6159 assert!(
6160 result.unwrap_err().contains("blocked"),
6161 "error should mention the fetch was blocked"
6162 );
6163}
6164
6165#[gpui::test]
6166async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6167 init_test(cx);
6168
6169 cx.update(|cx| {
6170 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6171 settings.tool_permissions.tools.insert(
6172 FetchTool::NAME.into(),
6173 agent_settings::ToolRules {
6174 default: Some(settings::ToolPermissionMode::Confirm),
6175 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
6176 always_deny: vec![],
6177 always_confirm: vec![],
6178 invalid_patterns: vec![],
6179 },
6180 );
6181 agent_settings::AgentSettings::override_global(settings, cx);
6182 });
6183
6184 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6185
6186 #[allow(clippy::arc_with_non_send_sync)]
6187 let tool = Arc::new(crate::FetchTool::new(http_client));
6188 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6189
6190 let input: crate::FetchToolInput =
6191 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
6192
6193 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6194
6195 cx.run_until_parked();
6196
6197 let event = rx.try_next();
6198 assert!(
6199 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6200 "expected no authorization request for allowed docs.rs URL"
6201 );
6202}
6203
// When a message is queued while a tool call is in flight, the turn must end
// (StopReason::EndTurn) at the next tool-completion boundary instead of
// sending another completion request, and the queued flag must survive.
#[gpui::test]
async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Add a tool so we can simulate tool calls
    thread.update(cx, |thread, _cx| {
        thread.add_tool(EchoTool);
    });

    // Start a turn by sending a message
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate the model making a tool call
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: r#"{"text": "hello"}"#.into(),
            input: json!({"text": "hello"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));

    // Signal that a message is queued before ending the stream
    thread.update(cx, |thread, _cx| {
        thread.set_has_queued_message(true);
    });

    // Now end the stream - tool will run, and the boundary check should see the queue
    fake_model.end_last_completion_stream();

    // Collect all events until the turn stops
    let all_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we received the tool call event
    let tool_call_ids: Vec<_> = all_events
        .iter()
        .filter_map(|e| match e {
            Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
            _ => None,
        })
        .collect();
    assert_eq!(
        tool_call_ids,
        vec!["tool_1"],
        "Should have received a tool call event for our echo tool"
    );

    // The turn should have stopped with EndTurn
    let stop_reasons = stop_events(all_events);
    assert_eq!(
        stop_reasons,
        vec![acp::StopReason::EndTurn],
        "Turn should have ended after tool completion due to queued message"
    );

    // Verify the queued message flag is still set
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.has_queued_message(),
            "Should still have queued message flag set"
        );
    });

    // Thread should be idle now
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Thread should not be running after turn ends"
        );
    });
}
6288
6289#[gpui::test]
6290async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) {
6291 init_test(cx);
6292 always_allow_tools(cx);
6293
6294 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
6295 let fake_model = model.as_fake();
6296
6297 thread.update(cx, |thread, _cx| {
6298 thread.add_tool(StreamingFailingEchoTool {
6299 receive_chunks_until_failure: 1,
6300 });
6301 });
6302
6303 let _events = thread
6304 .update(cx, |thread, cx| {
6305 thread.send(
6306 UserMessageId::new(),
6307 ["Use the streaming_failing_echo tool"],
6308 cx,
6309 )
6310 })
6311 .unwrap();
6312 cx.run_until_parked();
6313
6314 let tool_use = LanguageModelToolUse {
6315 id: "call_1".into(),
6316 name: StreamingFailingEchoTool::NAME.into(),
6317 raw_input: "hello".into(),
6318 input: json!({}),
6319 is_input_complete: false,
6320 thought_signature: None,
6321 };
6322
6323 fake_model
6324 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
6325
6326 cx.run_until_parked();
6327
6328 let completions = fake_model.pending_completions();
6329 let last_completion = completions.last().unwrap();
6330
6331 assert_eq!(
6332 last_completion.messages[1..],
6333 vec![
6334 LanguageModelRequestMessage {
6335 role: Role::User,
6336 content: vec!["Use the streaming_failing_echo tool".into()],
6337 cache: false,
6338 reasoning_details: None,
6339 },
6340 LanguageModelRequestMessage {
6341 role: Role::Assistant,
6342 content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
6343 cache: false,
6344 reasoning_details: None,
6345 },
6346 LanguageModelRequestMessage {
6347 role: Role::User,
6348 content: vec![language_model::MessageContent::ToolResult(
6349 LanguageModelToolResult {
6350 tool_use_id: tool_use.id.clone(),
6351 tool_name: tool_use.name,
6352 is_error: true,
6353 content: "failed".into(),
6354 output: Some("failed".into()),
6355 }
6356 )],
6357 cache: true,
6358 reasoning_details: None,
6359 },
6360 ]
6361 );
6362}
6363
6364#[gpui::test]
6365async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) {
6366 init_test(cx);
6367 always_allow_tools(cx);
6368
6369 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
6370 let fake_model = model.as_fake();
6371
6372 let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) =
6373 oneshot::channel();
6374
6375 thread.update(cx, |thread, _cx| {
6376 thread.add_tool(
6377 StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx),
6378 );
6379 thread.add_tool(StreamingFailingEchoTool {
6380 receive_chunks_until_failure: 1,
6381 });
6382 });
6383
6384 let _events = thread
6385 .update(cx, |thread, cx| {
6386 thread.send(
6387 UserMessageId::new(),
6388 ["Use the streaming_echo tool and the streaming_failing_echo tool"],
6389 cx,
6390 )
6391 })
6392 .unwrap();
6393 cx.run_until_parked();
6394
6395 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
6396 LanguageModelToolUse {
6397 id: "call_1".into(),
6398 name: StreamingEchoTool::NAME.into(),
6399 raw_input: "hello".into(),
6400 input: json!({ "text": "hello" }),
6401 is_input_complete: false,
6402 thought_signature: None,
6403 },
6404 ));
6405 let first_tool_use = LanguageModelToolUse {
6406 id: "call_1".into(),
6407 name: StreamingEchoTool::NAME.into(),
6408 raw_input: "hello world".into(),
6409 input: json!({ "text": "hello world" }),
6410 is_input_complete: true,
6411 thought_signature: None,
6412 };
6413 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
6414 first_tool_use.clone(),
6415 ));
6416 let second_tool_use = LanguageModelToolUse {
6417 name: StreamingFailingEchoTool::NAME.into(),
6418 raw_input: "hello".into(),
6419 input: json!({ "text": "hello" }),
6420 is_input_complete: false,
6421 thought_signature: None,
6422 id: "call_2".into(),
6423 };
6424 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
6425 second_tool_use.clone(),
6426 ));
6427
6428 cx.run_until_parked();
6429
6430 complete_streaming_echo_tool_call_tx.send(()).unwrap();
6431
6432 cx.run_until_parked();
6433
6434 let completions = fake_model.pending_completions();
6435 let last_completion = completions.last().unwrap();
6436
6437 assert_eq!(
6438 last_completion.messages[1..],
6439 vec![
6440 LanguageModelRequestMessage {
6441 role: Role::User,
6442 content: vec![
6443 "Use the streaming_echo tool and the streaming_failing_echo tool".into()
6444 ],
6445 cache: false,
6446 reasoning_details: None,
6447 },
6448 LanguageModelRequestMessage {
6449 role: Role::Assistant,
6450 content: vec![
6451 language_model::MessageContent::ToolUse(first_tool_use.clone()),
6452 language_model::MessageContent::ToolUse(second_tool_use.clone())
6453 ],
6454 cache: false,
6455 reasoning_details: None,
6456 },
6457 LanguageModelRequestMessage {
6458 role: Role::User,
6459 content: vec![
6460 language_model::MessageContent::ToolResult(LanguageModelToolResult {
6461 tool_use_id: second_tool_use.id.clone(),
6462 tool_name: second_tool_use.name,
6463 is_error: true,
6464 content: "failed".into(),
6465 output: Some("failed".into()),
6466 }),
6467 language_model::MessageContent::ToolResult(LanguageModelToolResult {
6468 tool_use_id: first_tool_use.id.clone(),
6469 tool_name: first_tool_use.name,
6470 is_error: false,
6471 content: "hello world".into(),
6472 output: Some("hello world".into()),
6473 }),
6474 ],
6475 cache: true,
6476 reasoning_details: None,
6477 },
6478 ]
6479 );
6480}