1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use feature_flags::FeatureFlagAppExt as _;
11use fs::{FakeFs, Fs};
12use futures::{
13 FutureExt as _, StreamExt,
14 channel::{
15 mpsc::{self, UnboundedReceiver},
16 oneshot,
17 },
18 future::{Fuse, Shared},
19};
20use gpui::{
21 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
22 http_client::FakeHttpClient,
23};
24use indoc::indoc;
25use language_model::{
26 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
27 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
29 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
30};
31use pretty_assertions::assert_eq;
32use project::{
33 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
34};
35use prompt_store::ProjectContext;
36use reqwest_client::ReqwestClient;
37use schemars::JsonSchema;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use settings::{Settings, SettingsStore};
41use std::{
42 path::Path,
43 pin::Pin,
44 rc::Rc,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, Ordering},
48 },
49 time::Duration,
50};
51use util::path;
52
53mod test_tools;
54use test_tools::*;
55
56fn init_test(cx: &mut TestAppContext) {
57 cx.update(|cx| {
58 let settings_store = SettingsStore::test(cx);
59 cx.set_global(settings_store);
60 });
61}
62
/// Test double for `crate::TerminalHandle`.
///
/// It records whether `kill` was called and reports a canned exit status and
/// output. A test decides when (or whether) the fake process "exits".
struct FakeTerminalHandle {
    /// Set to `true` by `kill`; read back via `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; tests set it with `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Exit signal, taken and fired at most once by `signal_exit`.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task handed out by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned response returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `id`.
    id: acp::TerminalId,
}
71
72impl FakeTerminalHandle {
73 fn new_never_exits(cx: &mut App) -> Self {
74 let killed = Arc::new(AtomicBool::new(false));
75 let stopped_by_user = Arc::new(AtomicBool::new(false));
76
77 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
78
79 let wait_for_exit = cx
80 .spawn(async move |_cx| {
81 // Wait for the exit signal (sent when kill() is called)
82 let _ = exit_receiver.await;
83 acp::TerminalExitStatus::new()
84 })
85 .shared();
86
87 Self {
88 killed,
89 stopped_by_user,
90 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
91 wait_for_exit,
92 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
93 id: acp::TerminalId::new("fake_terminal".to_string()),
94 }
95 }
96
97 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
98 let killed = Arc::new(AtomicBool::new(false));
99 let stopped_by_user = Arc::new(AtomicBool::new(false));
100 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
101
102 let wait_for_exit = cx
103 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
104 .shared();
105
106 Self {
107 killed,
108 stopped_by_user,
109 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
110 wait_for_exit,
111 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
112 id: acp::TerminalId::new("fake_terminal".to_string()),
113 }
114 }
115
116 fn was_killed(&self) -> bool {
117 self.killed.load(Ordering::SeqCst)
118 }
119
120 fn set_stopped_by_user(&self, stopped: bool) {
121 self.stopped_by_user.store(stopped, Ordering::SeqCst);
122 }
123
124 fn signal_exit(&self) {
125 if let Some(sender) = self.exit_sender.borrow_mut().take() {
126 let _ = sender.send(());
127 }
128 }
129}
130
131impl crate::TerminalHandle for FakeTerminalHandle {
132 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
133 Ok(self.id.clone())
134 }
135
136 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
137 Ok(self.output.clone())
138 }
139
140 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
141 Ok(self.wait_for_exit.clone())
142 }
143
144 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
145 self.killed.store(true, Ordering::SeqCst);
146 self.signal_exit();
147 Ok(())
148 }
149
150 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
151 Ok(self.stopped_by_user.load(Ordering::SeqCst))
152 }
153}
154
155struct FakeThreadEnvironment {
156 handle: Rc<FakeTerminalHandle>,
157}
158
159impl crate::ThreadEnvironment for FakeThreadEnvironment {
160 fn create_terminal(
161 &self,
162 _command: String,
163 _cwd: Option<std::path::PathBuf>,
164 _output_byte_limit: Option<u64>,
165 _cx: &mut AsyncApp,
166 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
167 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
168 }
169}
170
171/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
172struct MultiTerminalEnvironment {
173 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
174}
175
176impl MultiTerminalEnvironment {
177 fn new() -> Self {
178 Self {
179 handles: std::cell::RefCell::new(Vec::new()),
180 }
181 }
182
183 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
184 self.handles.borrow().clone()
185 }
186}
187
188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
189 fn create_terminal(
190 &self,
191 _command: String,
192 _cwd: Option<std::path::PathBuf>,
193 _output_byte_limit: Option<u64>,
194 cx: &mut AsyncApp,
195 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
196 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
197 self.handles.borrow_mut().push(handle.clone());
198 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
199 }
200}
201
202fn always_allow_tools(cx: &mut TestAppContext) {
203 cx.update(|cx| {
204 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
205 settings.always_allow_tool_actions = true;
206 agent_settings::AgentSettings::override_global(settings, cx);
207 });
208}
209
210#[gpui::test]
211async fn test_echo(cx: &mut TestAppContext) {
212 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
213 let fake_model = model.as_fake();
214
215 let events = thread
216 .update(cx, |thread, cx| {
217 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
218 })
219 .unwrap();
220 cx.run_until_parked();
221 fake_model.send_last_completion_stream_text_chunk("Hello");
222 fake_model
223 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
224 fake_model.end_last_completion_stream();
225
226 let events = events.collect().await;
227 thread.update(cx, |thread, _cx| {
228 assert_eq!(
229 thread.last_message().unwrap().to_markdown(),
230 indoc! {"
231 ## Assistant
232
233 Hello
234 "}
235 )
236 });
237 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
238}
239
240#[gpui::test]
241async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
242 init_test(cx);
243 always_allow_tools(cx);
244
245 let fs = FakeFs::new(cx.executor());
246 let project = Project::test(fs, [], cx).await;
247
248 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
249 let environment = Rc::new(FakeThreadEnvironment {
250 handle: handle.clone(),
251 });
252
253 #[allow(clippy::arc_with_non_send_sync)]
254 let tool = Arc::new(crate::TerminalTool::new(project, environment));
255 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
256
257 let task = cx.update(|cx| {
258 tool.run(
259 crate::TerminalToolInput {
260 command: "sleep 1000".to_string(),
261 cd: ".".to_string(),
262 timeout_ms: Some(5),
263 },
264 event_stream,
265 cx,
266 )
267 });
268
269 let update = rx.expect_update_fields().await;
270 assert!(
271 update.content.iter().any(|blocks| {
272 blocks
273 .iter()
274 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
275 }),
276 "expected tool call update to include terminal content"
277 );
278
279 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
280
281 let deadline = std::time::Instant::now() + Duration::from_millis(500);
282 loop {
283 if let Some(result) = task_future.as_mut().now_or_never() {
284 let result = result.expect("terminal tool task should complete");
285
286 assert!(
287 handle.was_killed(),
288 "expected terminal handle to be killed on timeout"
289 );
290 assert!(
291 result.contains("partial output"),
292 "expected result to include terminal output, got: {result}"
293 );
294 return;
295 }
296
297 if std::time::Instant::now() >= deadline {
298 panic!("timed out waiting for terminal tool task to complete");
299 }
300
301 cx.run_until_parked();
302 cx.background_executor.timer(Duration::from_millis(1)).await;
303 }
304}
305
306#[gpui::test]
307#[ignore]
308async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
309 init_test(cx);
310 always_allow_tools(cx);
311
312 let fs = FakeFs::new(cx.executor());
313 let project = Project::test(fs, [], cx).await;
314
315 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
316 let environment = Rc::new(FakeThreadEnvironment {
317 handle: handle.clone(),
318 });
319
320 #[allow(clippy::arc_with_non_send_sync)]
321 let tool = Arc::new(crate::TerminalTool::new(project, environment));
322 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
323
324 let _task = cx.update(|cx| {
325 tool.run(
326 crate::TerminalToolInput {
327 command: "sleep 1000".to_string(),
328 cd: ".".to_string(),
329 timeout_ms: None,
330 },
331 event_stream,
332 cx,
333 )
334 });
335
336 let update = rx.expect_update_fields().await;
337 assert!(
338 update.content.iter().any(|blocks| {
339 blocks
340 .iter()
341 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
342 }),
343 "expected tool call update to include terminal content"
344 );
345
346 cx.background_executor
347 .timer(Duration::from_millis(25))
348 .await;
349
350 assert!(
351 !handle.was_killed(),
352 "did not expect terminal handle to be killed without a timeout"
353 );
354}
355
356#[gpui::test]
357async fn test_thinking(cx: &mut TestAppContext) {
358 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
359 let fake_model = model.as_fake();
360
361 let events = thread
362 .update(cx, |thread, cx| {
363 thread.send(
364 UserMessageId::new(),
365 [indoc! {"
366 Testing:
367
368 Generate a thinking step where you just think the word 'Think',
369 and have your final answer be 'Hello'
370 "}],
371 cx,
372 )
373 })
374 .unwrap();
375 cx.run_until_parked();
376 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
377 text: "Think".to_string(),
378 signature: None,
379 });
380 fake_model.send_last_completion_stream_text_chunk("Hello");
381 fake_model
382 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
383 fake_model.end_last_completion_stream();
384
385 let events = events.collect().await;
386 thread.update(cx, |thread, _cx| {
387 assert_eq!(
388 thread.last_message().unwrap().to_markdown(),
389 indoc! {"
390 ## Assistant
391
392 <think>Think</think>
393 Hello
394 "}
395 )
396 });
397 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
398}
399
400#[gpui::test]
401async fn test_system_prompt(cx: &mut TestAppContext) {
402 let ThreadTest {
403 model,
404 thread,
405 project_context,
406 ..
407 } = setup(cx, TestModel::Fake).await;
408 let fake_model = model.as_fake();
409
410 project_context.update(cx, |project_context, _cx| {
411 project_context.shell = "test-shell".into()
412 });
413 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
414 thread
415 .update(cx, |thread, cx| {
416 thread.send(UserMessageId::new(), ["abc"], cx)
417 })
418 .unwrap();
419 cx.run_until_parked();
420 let mut pending_completions = fake_model.pending_completions();
421 assert_eq!(
422 pending_completions.len(),
423 1,
424 "unexpected pending completions: {:?}",
425 pending_completions
426 );
427
428 let pending_completion = pending_completions.pop().unwrap();
429 assert_eq!(pending_completion.messages[0].role, Role::System);
430
431 let system_message = &pending_completion.messages[0];
432 let system_prompt = system_message.content[0].to_str().unwrap();
433 assert!(
434 system_prompt.contains("test-shell"),
435 "unexpected system message: {:?}",
436 system_message
437 );
438 assert!(
439 system_prompt.contains("## Fixing Diagnostics"),
440 "unexpected system message: {:?}",
441 system_message
442 );
443}
444
445#[gpui::test]
446async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
448 let fake_model = model.as_fake();
449
450 thread
451 .update(cx, |thread, cx| {
452 thread.send(UserMessageId::new(), ["abc"], cx)
453 })
454 .unwrap();
455 cx.run_until_parked();
456 let mut pending_completions = fake_model.pending_completions();
457 assert_eq!(
458 pending_completions.len(),
459 1,
460 "unexpected pending completions: {:?}",
461 pending_completions
462 );
463
464 let pending_completion = pending_completions.pop().unwrap();
465 assert_eq!(pending_completion.messages[0].role, Role::System);
466
467 let system_message = &pending_completion.messages[0];
468 let system_prompt = system_message.content[0].to_str().unwrap();
469 assert!(
470 !system_prompt.contains("## Tool Use"),
471 "unexpected system message: {:?}",
472 system_message
473 );
474 assert!(
475 !system_prompt.contains("## Fixing Diagnostics"),
476 "unexpected system message: {:?}",
477 system_message
478 );
479}
480
481#[gpui::test]
482async fn test_prompt_caching(cx: &mut TestAppContext) {
483 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
484 let fake_model = model.as_fake();
485
486 // Send initial user message and verify it's cached
487 thread
488 .update(cx, |thread, cx| {
489 thread.send(UserMessageId::new(), ["Message 1"], cx)
490 })
491 .unwrap();
492 cx.run_until_parked();
493
494 let completion = fake_model.pending_completions().pop().unwrap();
495 assert_eq!(
496 completion.messages[1..],
497 vec![LanguageModelRequestMessage {
498 role: Role::User,
499 content: vec!["Message 1".into()],
500 cache: true,
501 reasoning_details: None,
502 }]
503 );
504 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
505 "Response to Message 1".into(),
506 ));
507 fake_model.end_last_completion_stream();
508 cx.run_until_parked();
509
510 // Send another user message and verify only the latest is cached
511 thread
512 .update(cx, |thread, cx| {
513 thread.send(UserMessageId::new(), ["Message 2"], cx)
514 })
515 .unwrap();
516 cx.run_until_parked();
517
518 let completion = fake_model.pending_completions().pop().unwrap();
519 assert_eq!(
520 completion.messages[1..],
521 vec![
522 LanguageModelRequestMessage {
523 role: Role::User,
524 content: vec!["Message 1".into()],
525 cache: false,
526 reasoning_details: None,
527 },
528 LanguageModelRequestMessage {
529 role: Role::Assistant,
530 content: vec!["Response to Message 1".into()],
531 cache: false,
532 reasoning_details: None,
533 },
534 LanguageModelRequestMessage {
535 role: Role::User,
536 content: vec!["Message 2".into()],
537 cache: true,
538 reasoning_details: None,
539 }
540 ]
541 );
542 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
543 "Response to Message 2".into(),
544 ));
545 fake_model.end_last_completion_stream();
546 cx.run_until_parked();
547
548 // Simulate a tool call and verify that the latest tool result is cached
549 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
550 thread
551 .update(cx, |thread, cx| {
552 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
553 })
554 .unwrap();
555 cx.run_until_parked();
556
557 let tool_use = LanguageModelToolUse {
558 id: "tool_1".into(),
559 name: EchoTool::name().into(),
560 raw_input: json!({"text": "test"}).to_string(),
561 input: json!({"text": "test"}),
562 is_input_complete: true,
563 thought_signature: None,
564 };
565 fake_model
566 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
567 fake_model.end_last_completion_stream();
568 cx.run_until_parked();
569
570 let completion = fake_model.pending_completions().pop().unwrap();
571 let tool_result = LanguageModelToolResult {
572 tool_use_id: "tool_1".into(),
573 tool_name: EchoTool::name().into(),
574 is_error: false,
575 content: "test".into(),
576 output: Some("test".into()),
577 };
578 assert_eq!(
579 completion.messages[1..],
580 vec![
581 LanguageModelRequestMessage {
582 role: Role::User,
583 content: vec!["Message 1".into()],
584 cache: false,
585 reasoning_details: None,
586 },
587 LanguageModelRequestMessage {
588 role: Role::Assistant,
589 content: vec!["Response to Message 1".into()],
590 cache: false,
591 reasoning_details: None,
592 },
593 LanguageModelRequestMessage {
594 role: Role::User,
595 content: vec!["Message 2".into()],
596 cache: false,
597 reasoning_details: None,
598 },
599 LanguageModelRequestMessage {
600 role: Role::Assistant,
601 content: vec!["Response to Message 2".into()],
602 cache: false,
603 reasoning_details: None,
604 },
605 LanguageModelRequestMessage {
606 role: Role::User,
607 content: vec!["Use the echo tool".into()],
608 cache: false,
609 reasoning_details: None,
610 },
611 LanguageModelRequestMessage {
612 role: Role::Assistant,
613 content: vec![MessageContent::ToolUse(tool_use)],
614 cache: false,
615 reasoning_details: None,
616 },
617 LanguageModelRequestMessage {
618 role: Role::User,
619 content: vec![MessageContent::ToolResult(tool_result)],
620 cache: true,
621 reasoning_details: None,
622 }
623 ]
624 );
625}
626
627#[gpui::test]
628#[cfg_attr(not(feature = "e2e"), ignore)]
629async fn test_basic_tool_calls(cx: &mut TestAppContext) {
630 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
631
632 // Test a tool call that's likely to complete *before* streaming stops.
633 let events = thread
634 .update(cx, |thread, cx| {
635 thread.add_tool(EchoTool);
636 thread.send(
637 UserMessageId::new(),
638 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
639 cx,
640 )
641 })
642 .unwrap()
643 .collect()
644 .await;
645 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
646
647 // Test a tool calls that's likely to complete *after* streaming stops.
648 let events = thread
649 .update(cx, |thread, cx| {
650 thread.remove_tool(&EchoTool::name());
651 thread.add_tool(DelayTool);
652 thread.send(
653 UserMessageId::new(),
654 [
655 "Now call the delay tool with 200ms.",
656 "When the timer goes off, then you echo the output of the tool.",
657 ],
658 cx,
659 )
660 })
661 .unwrap()
662 .collect()
663 .await;
664 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
665 thread.update(cx, |thread, _cx| {
666 assert!(
667 thread
668 .last_message()
669 .unwrap()
670 .as_agent_message()
671 .unwrap()
672 .content
673 .iter()
674 .any(|content| {
675 if let AgentMessageContent::Text(text) = content {
676 text.contains("Ding")
677 } else {
678 false
679 }
680 }),
681 "{}",
682 thread.to_markdown()
683 );
684 });
685}
686
687#[gpui::test]
688#[cfg_attr(not(feature = "e2e"), ignore)]
689async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
690 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
691
692 // Test a tool call that's likely to complete *before* streaming stops.
693 let mut events = thread
694 .update(cx, |thread, cx| {
695 thread.add_tool(WordListTool);
696 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
697 })
698 .unwrap();
699
700 let mut saw_partial_tool_use = false;
701 while let Some(event) = events.next().await {
702 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
703 thread.update(cx, |thread, _cx| {
704 // Look for a tool use in the thread's last message
705 let message = thread.last_message().unwrap();
706 let agent_message = message.as_agent_message().unwrap();
707 let last_content = agent_message.content.last().unwrap();
708 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
709 assert_eq!(last_tool_use.name.as_ref(), "word_list");
710 if tool_call.status == acp::ToolCallStatus::Pending {
711 if !last_tool_use.is_input_complete
712 && last_tool_use.input.get("g").is_none()
713 {
714 saw_partial_tool_use = true;
715 }
716 } else {
717 last_tool_use
718 .input
719 .get("a")
720 .expect("'a' has streamed because input is now complete");
721 last_tool_use
722 .input
723 .get("g")
724 .expect("'g' has streamed because input is now complete");
725 }
726 } else {
727 panic!("last content should be a tool use");
728 }
729 });
730 }
731 }
732
733 assert!(
734 saw_partial_tool_use,
735 "should see at least one partially streamed tool use in the history"
736 );
737}
738
/// Walks a permission-gated tool through the authorization flow: allow once,
/// deny, "always allow", and then a call that must run without producing a
/// new authorization request.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // The model requests the permission-gated tool twice in one turn.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // Each call surfaces its own authorization request, in order.
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries both results: the allowed call succeeded
    // and the denied call is reported as an error.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    // No authorization response is sent here; the tool result below must come
    // from the "always allow" choice recorded above.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
877
878#[gpui::test]
879async fn test_tool_hallucination(cx: &mut TestAppContext) {
880 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
881 let fake_model = model.as_fake();
882
883 let mut events = thread
884 .update(cx, |thread, cx| {
885 thread.send(UserMessageId::new(), ["abc"], cx)
886 })
887 .unwrap();
888 cx.run_until_parked();
889 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
890 LanguageModelToolUse {
891 id: "tool_id_1".into(),
892 name: "nonexistent_tool".into(),
893 raw_input: "{}".into(),
894 input: json!({}),
895 is_input_complete: true,
896 thought_signature: None,
897 },
898 ));
899 fake_model.end_last_completion_stream();
900
901 let tool_call = expect_tool_call(&mut events).await;
902 assert_eq!(tool_call.title, "nonexistent_tool");
903 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
904 let update = expect_tool_call_update_fields(&mut events).await;
905 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
906}
907
908async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
909 let event = events
910 .next()
911 .await
912 .expect("no tool call authorization event received")
913 .unwrap();
914 match event {
915 ThreadEvent::ToolCall(tool_call) => tool_call,
916 event => {
917 panic!("Unexpected event {event:?}");
918 }
919 }
920}
921
922async fn expect_tool_call_update_fields(
923 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
924) -> acp::ToolCallUpdate {
925 let event = events
926 .next()
927 .await
928 .expect("no tool call authorization event received")
929 .unwrap();
930 match event {
931 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
932 event => {
933 panic!("Unexpected event {event:?}");
934 }
935 }
936}
937
938async fn next_tool_call_authorization(
939 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
940) -> ToolCallAuthorization {
941 loop {
942 let event = events
943 .next()
944 .await
945 .expect("no tool call authorization event received")
946 .unwrap();
947 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
948 let permission_kinds = tool_call_authorization
949 .options
950 .iter()
951 .map(|o| o.kind)
952 .collect::<Vec<_>>();
953 // Only 2 options now: AllowAlways (for tool) and AllowOnce (granularity only)
954 // Deny is handled by the UI buttons, not as a separate option
955 assert_eq!(
956 permission_kinds,
957 vec![
958 acp::PermissionOptionKind::AllowAlways,
959 acp::PermissionOptionKind::AllowOnce,
960 ]
961 );
962 return tool_call_authorization;
963 }
964 }
965}
966
967#[gpui::test]
968#[cfg_attr(not(feature = "e2e"), ignore)]
969async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
970 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
971
972 // Test concurrent tool calls with different delay times
973 let events = thread
974 .update(cx, |thread, cx| {
975 thread.add_tool(DelayTool);
976 thread.send(
977 UserMessageId::new(),
978 [
979 "Call the delay tool twice in the same message.",
980 "Once with 100ms. Once with 300ms.",
981 "When both timers are complete, describe the outputs.",
982 ],
983 cx,
984 )
985 })
986 .unwrap()
987 .collect()
988 .await;
989
990 let stop_reasons = stop_events(events);
991 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
992
993 thread.update(cx, |thread, _cx| {
994 let last_message = thread.last_message().unwrap();
995 let agent_message = last_message.as_agent_message().unwrap();
996 let text = agent_message
997 .content
998 .iter()
999 .filter_map(|content| {
1000 if let AgentMessageContent::Text(text) = content {
1001 Some(text.as_str())
1002 } else {
1003 None
1004 }
1005 })
1006 .collect::<String>();
1007
1008 assert!(text.contains("Ding"));
1009 });
1010}
1011
1012#[gpui::test]
1013async fn test_profiles(cx: &mut TestAppContext) {
1014 let ThreadTest {
1015 model, thread, fs, ..
1016 } = setup(cx, TestModel::Fake).await;
1017 let fake_model = model.as_fake();
1018
1019 thread.update(cx, |thread, _cx| {
1020 thread.add_tool(DelayTool);
1021 thread.add_tool(EchoTool);
1022 thread.add_tool(InfiniteTool);
1023 });
1024
1025 // Override profiles and wait for settings to be loaded.
1026 fs.insert_file(
1027 paths::settings_file(),
1028 json!({
1029 "agent": {
1030 "profiles": {
1031 "test-1": {
1032 "name": "Test Profile 1",
1033 "tools": {
1034 EchoTool::name(): true,
1035 DelayTool::name(): true,
1036 }
1037 },
1038 "test-2": {
1039 "name": "Test Profile 2",
1040 "tools": {
1041 InfiniteTool::name(): true,
1042 }
1043 }
1044 }
1045 }
1046 })
1047 .to_string()
1048 .into_bytes(),
1049 )
1050 .await;
1051 cx.run_until_parked();
1052
1053 // Test that test-1 profile (default) has echo and delay tools
1054 thread
1055 .update(cx, |thread, cx| {
1056 thread.set_profile(AgentProfileId("test-1".into()), cx);
1057 thread.send(UserMessageId::new(), ["test"], cx)
1058 })
1059 .unwrap();
1060 cx.run_until_parked();
1061
1062 let mut pending_completions = fake_model.pending_completions();
1063 assert_eq!(pending_completions.len(), 1);
1064 let completion = pending_completions.pop().unwrap();
1065 let tool_names: Vec<String> = completion
1066 .tools
1067 .iter()
1068 .map(|tool| tool.name.clone())
1069 .collect();
1070 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1071 fake_model.end_last_completion_stream();
1072
1073 // Switch to test-2 profile, and verify that it has only the infinite tool.
1074 thread
1075 .update(cx, |thread, cx| {
1076 thread.set_profile(AgentProfileId("test-2".into()), cx);
1077 thread.send(UserMessageId::new(), ["test2"], cx)
1078 })
1079 .unwrap();
1080 cx.run_until_parked();
1081 let mut pending_completions = fake_model.pending_completions();
1082 assert_eq!(pending_completions.len(), 1);
1083 let completion = pending_completions.pop().unwrap();
1084 let tool_names: Vec<String> = completion
1085 .tools
1086 .iter()
1087 .map(|tool| tool.name.clone())
1088 .collect();
1089 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1090}
1091
1092#[gpui::test]
1093async fn test_mcp_tools(cx: &mut TestAppContext) {
1094 let ThreadTest {
1095 model,
1096 thread,
1097 context_server_store,
1098 fs,
1099 ..
1100 } = setup(cx, TestModel::Fake).await;
1101 let fake_model = model.as_fake();
1102
1103 // Override profiles and wait for settings to be loaded.
1104 fs.insert_file(
1105 paths::settings_file(),
1106 json!({
1107 "agent": {
1108 "always_allow_tool_actions": true,
1109 "profiles": {
1110 "test": {
1111 "name": "Test Profile",
1112 "enable_all_context_servers": true,
1113 "tools": {
1114 EchoTool::name(): true,
1115 }
1116 },
1117 }
1118 }
1119 })
1120 .to_string()
1121 .into_bytes(),
1122 )
1123 .await;
1124 cx.run_until_parked();
1125 thread.update(cx, |thread, cx| {
1126 thread.set_profile(AgentProfileId("test".into()), cx)
1127 });
1128
1129 let mut mcp_tool_calls = setup_context_server(
1130 "test_server",
1131 vec![context_server::types::Tool {
1132 name: "echo".into(),
1133 description: None,
1134 input_schema: serde_json::to_value(EchoTool::input_schema(
1135 LanguageModelToolSchemaFormat::JsonSchema,
1136 ))
1137 .unwrap(),
1138 output_schema: None,
1139 annotations: None,
1140 }],
1141 &context_server_store,
1142 cx,
1143 );
1144
1145 let events = thread.update(cx, |thread, cx| {
1146 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1147 });
1148 cx.run_until_parked();
1149
1150 // Simulate the model calling the MCP tool.
1151 let completion = fake_model.pending_completions().pop().unwrap();
1152 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1153 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1154 LanguageModelToolUse {
1155 id: "tool_1".into(),
1156 name: "echo".into(),
1157 raw_input: json!({"text": "test"}).to_string(),
1158 input: json!({"text": "test"}),
1159 is_input_complete: true,
1160 thought_signature: None,
1161 },
1162 ));
1163 fake_model.end_last_completion_stream();
1164 cx.run_until_parked();
1165
1166 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1167 assert_eq!(tool_call_params.name, "echo");
1168 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1169 tool_call_response
1170 .send(context_server::types::CallToolResponse {
1171 content: vec![context_server::types::ToolResponseContent::Text {
1172 text: "test".into(),
1173 }],
1174 is_error: None,
1175 meta: None,
1176 structured_content: None,
1177 })
1178 .unwrap();
1179 cx.run_until_parked();
1180
1181 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1182 fake_model.send_last_completion_stream_text_chunk("Done!");
1183 fake_model.end_last_completion_stream();
1184 events.collect::<Vec<_>>().await;
1185
1186 // Send again after adding the echo tool, ensuring the name collision is resolved.
1187 let events = thread.update(cx, |thread, cx| {
1188 thread.add_tool(EchoTool);
1189 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1190 });
1191 cx.run_until_parked();
1192 let completion = fake_model.pending_completions().pop().unwrap();
1193 assert_eq!(
1194 tool_names_for_completion(&completion),
1195 vec!["echo", "test_server_echo"]
1196 );
1197 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1198 LanguageModelToolUse {
1199 id: "tool_2".into(),
1200 name: "test_server_echo".into(),
1201 raw_input: json!({"text": "mcp"}).to_string(),
1202 input: json!({"text": "mcp"}),
1203 is_input_complete: true,
1204 thought_signature: None,
1205 },
1206 ));
1207 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1208 LanguageModelToolUse {
1209 id: "tool_3".into(),
1210 name: "echo".into(),
1211 raw_input: json!({"text": "native"}).to_string(),
1212 input: json!({"text": "native"}),
1213 is_input_complete: true,
1214 thought_signature: None,
1215 },
1216 ));
1217 fake_model.end_last_completion_stream();
1218 cx.run_until_parked();
1219
1220 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1221 assert_eq!(tool_call_params.name, "echo");
1222 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1223 tool_call_response
1224 .send(context_server::types::CallToolResponse {
1225 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1226 is_error: None,
1227 meta: None,
1228 structured_content: None,
1229 })
1230 .unwrap();
1231 cx.run_until_parked();
1232
1233 // Ensure the tool results were inserted with the correct names.
1234 let completion = fake_model.pending_completions().pop().unwrap();
1235 assert_eq!(
1236 completion.messages.last().unwrap().content,
1237 vec![
1238 MessageContent::ToolResult(LanguageModelToolResult {
1239 tool_use_id: "tool_3".into(),
1240 tool_name: "echo".into(),
1241 is_error: false,
1242 content: "native".into(),
1243 output: Some("native".into()),
1244 },),
1245 MessageContent::ToolResult(LanguageModelToolResult {
1246 tool_use_id: "tool_2".into(),
1247 tool_name: "test_server_echo".into(),
1248 is_error: false,
1249 content: "mcp".into(),
1250 output: Some("mcp".into()),
1251 },),
1252 ]
1253 );
1254 fake_model.end_last_completion_stream();
1255 events.collect::<Vec<_>>().await;
1256}
1257
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Registers context servers whose tool names collide with a native tool, with each other,
    // or are close to / above MAX_TOOL_NAME_LENGTH. Then checks the exact list of tool names
    // offered to the model on the next request.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names.
    // The returned call receivers are bound to `_`-prefixed names (rather than `_`) so they
    // stay alive until the end of the test. The calls themselves are never inspected.
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), // Also offered by server "zzz"
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), // Also offered by server "zzz"
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), // Same name as on server "yyy"
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), // Same name as on server "yyy"
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1), // Longer than MAX_TOOL_NAME_LENGTH
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Only the outgoing request matters here; the event receiver returned by `send` is dropped.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected names, listed alphabetically:
    // - Native tools keep their own names.
    // - The MCP `echo` tools that collide with the native one are prefixed with their server
    //   name (`xxx_echo`, `yyy_echo`).
    // - The `a…` name shared by servers `yyy` and `zzz` gets a one-letter server prefix
    //   (`y_`, `z_`).
    // NOTE(review): the `b…` name offered by both `yyy` and `zzz` appears only once, and the
    // over-long `c…` name appears unprefixed. Presumably this is dedup plus truncation to
    // MAX_TOOL_NAME_LENGTH; confirm against the tool-name resolution code.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1423
1424#[gpui::test]
1425#[cfg_attr(not(feature = "e2e"), ignore)]
1426async fn test_cancellation(cx: &mut TestAppContext) {
1427 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1428
1429 let mut events = thread
1430 .update(cx, |thread, cx| {
1431 thread.add_tool(InfiniteTool);
1432 thread.add_tool(EchoTool);
1433 thread.send(
1434 UserMessageId::new(),
1435 ["Call the echo tool, then call the infinite tool, then explain their output"],
1436 cx,
1437 )
1438 })
1439 .unwrap();
1440
1441 // Wait until both tools are called.
1442 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1443 let mut echo_id = None;
1444 let mut echo_completed = false;
1445 while let Some(event) = events.next().await {
1446 match event.unwrap() {
1447 ThreadEvent::ToolCall(tool_call) => {
1448 assert_eq!(tool_call.title, expected_tools.remove(0));
1449 if tool_call.title == "Echo" {
1450 echo_id = Some(tool_call.tool_call_id);
1451 }
1452 }
1453 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1454 acp::ToolCallUpdate {
1455 tool_call_id,
1456 fields:
1457 acp::ToolCallUpdateFields {
1458 status: Some(acp::ToolCallStatus::Completed),
1459 ..
1460 },
1461 ..
1462 },
1463 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1464 echo_completed = true;
1465 }
1466 _ => {}
1467 }
1468
1469 if expected_tools.is_empty() && echo_completed {
1470 break;
1471 }
1472 }
1473
1474 // Cancel the current send and ensure that the event stream is closed, even
1475 // if one of the tools is still running.
1476 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1477 let events = events.collect::<Vec<_>>().await;
1478 let last_event = events.last();
1479 assert!(
1480 matches!(
1481 last_event,
1482 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1483 ),
1484 "unexpected event {last_event:?}"
1485 );
1486
1487 // Ensure we can still send a new message after cancellation.
1488 let events = thread
1489 .update(cx, |thread, cx| {
1490 thread.send(
1491 UserMessageId::new(),
1492 ["Testing: reply with 'Hello' then stop."],
1493 cx,
1494 )
1495 })
1496 .unwrap()
1497 .collect::<Vec<_>>()
1498 .await;
1499 thread.update(cx, |thread, _cx| {
1500 let message = thread.last_message().unwrap();
1501 let agent_message = message.as_agent_message().unwrap();
1502 assert_eq!(
1503 agent_message.content,
1504 vec![AgentMessageContent::Text("Hello".to_string())]
1505 );
1506 });
1507 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1508}
1509
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // Cancels the thread while a terminal tool is still running. Verifies that:
    // - the terminal is killed;
    // - the turn stops with `Cancelled`;
    // - the recorded tool result contains the output captured so far plus a note that the
    //   user stopped the command.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // The handle is shared with the environment so the test can still inspect it
    // (via `was_killed`) after the tool has used it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running. The cancel task is detached;
    // `collect_events_until_stop` below drives it to completion.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the (first) tool use in the agent message, then its result by tool-use id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1606
1607#[gpui::test]
1608async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1609 // This test verifies that tools which properly handle cancellation via
1610 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1611 // to cancellation and report that they were cancelled.
1612 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1613 always_allow_tools(cx);
1614 let fake_model = model.as_fake();
1615
1616 let (tool, was_cancelled) = CancellationAwareTool::new();
1617
1618 let mut events = thread
1619 .update(cx, |thread, cx| {
1620 thread.add_tool(tool);
1621 thread.send(
1622 UserMessageId::new(),
1623 ["call the cancellation aware tool"],
1624 cx,
1625 )
1626 })
1627 .unwrap();
1628
1629 cx.run_until_parked();
1630
1631 // Simulate the model calling the cancellation-aware tool
1632 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1633 LanguageModelToolUse {
1634 id: "cancellation_aware_1".into(),
1635 name: "cancellation_aware".into(),
1636 raw_input: r#"{}"#.into(),
1637 input: json!({}),
1638 is_input_complete: true,
1639 thought_signature: None,
1640 },
1641 ));
1642 fake_model.end_last_completion_stream();
1643
1644 cx.run_until_parked();
1645
1646 // Wait for the tool call to be reported
1647 let mut tool_started = false;
1648 let deadline = cx.executor().num_cpus() * 100;
1649 for _ in 0..deadline {
1650 cx.run_until_parked();
1651
1652 while let Some(Some(event)) = events.next().now_or_never() {
1653 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1654 if tool_call.title == "Cancellation Aware Tool" {
1655 tool_started = true;
1656 break;
1657 }
1658 }
1659 }
1660
1661 if tool_started {
1662 break;
1663 }
1664
1665 cx.background_executor
1666 .timer(Duration::from_millis(10))
1667 .await;
1668 }
1669 assert!(tool_started, "expected cancellation aware tool to start");
1670
1671 // Cancel the thread and wait for it to complete
1672 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1673
1674 // The cancel task should complete promptly because the tool handles cancellation
1675 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1676 futures::select! {
1677 _ = cancel_task.fuse() => {}
1678 _ = timeout.fuse() => {
1679 panic!("cancel task timed out - tool did not respond to cancellation");
1680 }
1681 }
1682
1683 // Verify the tool detected cancellation via its flag
1684 assert!(
1685 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1686 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1687 );
1688
1689 // Collect remaining events
1690 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1691
1692 // Verify we got a cancellation stop event
1693 assert_eq!(
1694 stop_events(remaining_events),
1695 vec![acp::StopReason::Cancelled],
1696 );
1697
1698 // Verify we can send a new message after cancellation
1699 verify_thread_recovery(&thread, &fake_model, cx).await;
1700}
1701
1702/// Helper to verify thread can recover after cancellation by sending a simple message.
1703async fn verify_thread_recovery(
1704 thread: &Entity<Thread>,
1705 fake_model: &FakeLanguageModel,
1706 cx: &mut TestAppContext,
1707) {
1708 let events = thread
1709 .update(cx, |thread, cx| {
1710 thread.send(
1711 UserMessageId::new(),
1712 ["Testing: reply with 'Hello' then stop."],
1713 cx,
1714 )
1715 })
1716 .unwrap();
1717 cx.run_until_parked();
1718 fake_model.send_last_completion_stream_text_chunk("Hello");
1719 fake_model
1720 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1721 fake_model.end_last_completion_stream();
1722
1723 let events = events.collect::<Vec<_>>().await;
1724 thread.update(cx, |thread, _cx| {
1725 let message = thread.last_message().unwrap();
1726 let agent_message = message.as_agent_message().unwrap();
1727 assert_eq!(
1728 agent_message.content,
1729 vec![AgentMessageContent::Text("Hello".to_string())]
1730 );
1731 });
1732 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1733}
1734
1735/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1736async fn wait_for_terminal_tool_started(
1737 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1738 cx: &mut TestAppContext,
1739) {
1740 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1741 for _ in 0..deadline {
1742 cx.run_until_parked();
1743
1744 while let Some(Some(event)) = events.next().now_or_never() {
1745 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1746 update,
1747 ))) = &event
1748 {
1749 if update.fields.content.as_ref().is_some_and(|content| {
1750 content
1751 .iter()
1752 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1753 }) {
1754 return;
1755 }
1756 }
1757 }
1758
1759 cx.background_executor
1760 .timer(Duration::from_millis(10))
1761 .await;
1762 }
1763 panic!("terminal tool did not start within the expected time");
1764}
1765
1766/// Collects events until a Stop event is received, driving the executor to completion.
1767async fn collect_events_until_stop(
1768 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1769 cx: &mut TestAppContext,
1770) -> Vec<Result<ThreadEvent>> {
1771 let mut collected = Vec::new();
1772 let deadline = cx.executor().num_cpus() * 200;
1773
1774 for _ in 0..deadline {
1775 cx.executor().advance_clock(Duration::from_millis(10));
1776 cx.run_until_parked();
1777
1778 while let Some(Some(event)) = events.next().now_or_never() {
1779 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1780 collected.push(event);
1781 if is_stop {
1782 return collected;
1783 }
1784 }
1785 }
1786 panic!(
1787 "did not receive Stop event within the expected time; collected {} events",
1788 collected.len()
1789 );
1790}
1791
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncating the thread back to its only user message while a terminal tool is running
    // must kill the terminal, leave the thread empty, and keep the thread usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Shared with the environment so `was_killed` can be checked after truncation.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the id of the user message so the thread can be truncated back to it later.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete (the collected events are not inspected)
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1858
1859#[gpui::test]
1860async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1861 // Tests that cancellation properly kills all running terminal tools when multiple are active.
1862 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1863 always_allow_tools(cx);
1864 let fake_model = model.as_fake();
1865
1866 let environment = Rc::new(MultiTerminalEnvironment::new());
1867
1868 let mut events = thread
1869 .update(cx, |thread, cx| {
1870 thread.add_tool(crate::TerminalTool::new(
1871 thread.project().clone(),
1872 environment.clone(),
1873 ));
1874 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1875 })
1876 .unwrap();
1877
1878 cx.run_until_parked();
1879
1880 // Simulate the model calling two terminal tools
1881 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1882 LanguageModelToolUse {
1883 id: "terminal_tool_1".into(),
1884 name: "terminal".into(),
1885 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1886 input: json!({"command": "sleep 1000", "cd": "."}),
1887 is_input_complete: true,
1888 thought_signature: None,
1889 },
1890 ));
1891 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1892 LanguageModelToolUse {
1893 id: "terminal_tool_2".into(),
1894 name: "terminal".into(),
1895 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
1896 input: json!({"command": "sleep 2000", "cd": "."}),
1897 is_input_complete: true,
1898 thought_signature: None,
1899 },
1900 ));
1901 fake_model.end_last_completion_stream();
1902
1903 // Wait for both terminal tools to start by counting terminal content updates
1904 let mut terminals_started = 0;
1905 let deadline = cx.executor().num_cpus() * 100;
1906 for _ in 0..deadline {
1907 cx.run_until_parked();
1908
1909 while let Some(Some(event)) = events.next().now_or_never() {
1910 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1911 update,
1912 ))) = &event
1913 {
1914 if update.fields.content.as_ref().is_some_and(|content| {
1915 content
1916 .iter()
1917 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1918 }) {
1919 terminals_started += 1;
1920 if terminals_started >= 2 {
1921 break;
1922 }
1923 }
1924 }
1925 }
1926 if terminals_started >= 2 {
1927 break;
1928 }
1929
1930 cx.background_executor
1931 .timer(Duration::from_millis(10))
1932 .await;
1933 }
1934 assert!(
1935 terminals_started >= 2,
1936 "expected 2 terminal tools to start, got {terminals_started}"
1937 );
1938
1939 // Cancel the thread while both terminals are running
1940 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1941
1942 // Collect remaining events
1943 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1944
1945 // Verify both terminal handles were killed
1946 let handles = environment.handles();
1947 assert_eq!(
1948 handles.len(),
1949 2,
1950 "expected 2 terminal handles to be created"
1951 );
1952 assert!(
1953 handles[0].was_killed(),
1954 "expected first terminal handle to be killed on cancellation"
1955 );
1956 assert!(
1957 handles[1].was_killed(),
1958 "expected second terminal handle to be killed on cancellation"
1959 );
1960
1961 // Verify we got a cancellation stop event
1962 assert_eq!(
1963 stop_events(remaining_events),
1964 vec![acp::StopReason::Cancelled],
1965 );
1966}
1967
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    // The thread itself is not cancelled, so the turn is expected to end with `EndTurn`.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Shared with the environment so the test can flip the handle's flags directly below.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the (first) tool use in the agent message, then its result by tool-use id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2062
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    // The terminal must be killed, the turn ends normally (`EndTurn`), and the result must not
    // claim that the user stopped the command.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Shared with the environment so `was_killed` can be checked after the timeout.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance the test executor's clock past the 100ms `timeout_ms` given above
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the (first) tool use in the agent message, then its result by tool-use id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2161
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    // Starting a second send while the first turn is still streaming cancels the first
    // turn (its event stream stops with `Cancelled`); the second turn finishes normally.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Stream part of the first response but never end it, so the turn is still in progress.
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2192
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    // A send issued after the previous turn has fully completed must not retroactively
    // cancel it: both turns report `EndTurn`.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the first turn completely before starting the second one.
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2225
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with `StopReason::Refusal`, the turn reports `Refusal` and the
    // refused exchange is rolled back out of the thread.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove the last user message
    // together with everything after it; with a single exchange the thread ends up empty.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2274
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first user message empties the thread and clears token usage;
    // afterwards the thread still accepts new messages and tracks usage for them afresh.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // `used_tokens` is expected to be input + output tokens of the latest usage update.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Only the post-truncation usage is reported; the truncated turn's usage is gone.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2394
2395#[gpui::test]
2396async fn test_truncate_second_message(cx: &mut TestAppContext) {
2397 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2398 let fake_model = model.as_fake();
2399
2400 thread
2401 .update(cx, |thread, cx| {
2402 thread.send(UserMessageId::new(), ["Message 1"], cx)
2403 })
2404 .unwrap();
2405 cx.run_until_parked();
2406 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2407 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2408 language_model::TokenUsage {
2409 input_tokens: 32_000,
2410 output_tokens: 16_000,
2411 cache_creation_input_tokens: 0,
2412 cache_read_input_tokens: 0,
2413 },
2414 ));
2415 fake_model.end_last_completion_stream();
2416 cx.run_until_parked();
2417
2418 let assert_first_message_state = |cx: &mut TestAppContext| {
2419 thread.clone().read_with(cx, |thread, _| {
2420 assert_eq!(
2421 thread.to_markdown(),
2422 indoc! {"
2423 ## User
2424
2425 Message 1
2426
2427 ## Assistant
2428
2429 Message 1 response
2430 "}
2431 );
2432
2433 assert_eq!(
2434 thread.latest_token_usage(),
2435 Some(acp_thread::TokenUsage {
2436 used_tokens: 32_000 + 16_000,
2437 max_tokens: 1_000_000,
2438 input_tokens: 32_000,
2439 output_tokens: 16_000,
2440 })
2441 );
2442 });
2443 };
2444
2445 assert_first_message_state(cx);
2446
2447 let second_message_id = UserMessageId::new();
2448 thread
2449 .update(cx, |thread, cx| {
2450 thread.send(second_message_id.clone(), ["Message 2"], cx)
2451 })
2452 .unwrap();
2453 cx.run_until_parked();
2454
2455 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2456 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2457 language_model::TokenUsage {
2458 input_tokens: 40_000,
2459 output_tokens: 20_000,
2460 cache_creation_input_tokens: 0,
2461 cache_read_input_tokens: 0,
2462 },
2463 ));
2464 fake_model.end_last_completion_stream();
2465 cx.run_until_parked();
2466
2467 thread.read_with(cx, |thread, _| {
2468 assert_eq!(
2469 thread.to_markdown(),
2470 indoc! {"
2471 ## User
2472
2473 Message 1
2474
2475 ## Assistant
2476
2477 Message 1 response
2478
2479 ## User
2480
2481 Message 2
2482
2483 ## Assistant
2484
2485 Message 2 response
2486 "}
2487 );
2488
2489 assert_eq!(
2490 thread.latest_token_usage(),
2491 Some(acp_thread::TokenUsage {
2492 used_tokens: 40_000 + 20_000,
2493 max_tokens: 1_000_000,
2494 input_tokens: 40_000,
2495 output_tokens: 20_000,
2496 })
2497 );
2498 });
2499
2500 thread
2501 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2502 .unwrap();
2503 cx.run_until_parked();
2504
2505 assert_first_message_state(cx);
2506}
2507
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // The summarization model generates the thread title once, after the first exchange.
    // Only the first line of its output is used, and later messages do not trigger a new
    // title request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary completes, the default title is shown.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks deliberately split lines mid-stream: "Hello world\nGoodnight Moon" must
    // yield just "Hello world".
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2553
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When one tool call is still waiting (here: a tool that requires permission) and another
    // has finished, `build_completion_request` must include only the finished tool use and its
    // result, leaving the pending call out of the request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the event stream alive so the turn is not dropped while we inspect it.
    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    // `messages[1..]` skips the first request message (presumably the system prompt).
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2633
2634#[gpui::test]
2635async fn test_agent_connection(cx: &mut TestAppContext) {
2636 cx.update(settings::init);
2637 let templates = Templates::new();
2638
2639 // Initialize language model system with test provider
2640 cx.update(|cx| {
2641 gpui_tokio::init(cx);
2642
2643 let http_client = FakeHttpClient::with_404_response();
2644 let clock = Arc::new(clock::FakeSystemClock::new());
2645 let client = Client::new(clock, http_client, cx);
2646 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2647 language_model::init(client.clone(), cx);
2648 language_models::init(user_store, client.clone(), cx);
2649 LanguageModelRegistry::test(cx);
2650 });
2651 cx.executor().forbid_parking();
2652
2653 // Create a project for new_thread
2654 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2655 fake_fs.insert_tree(path!("/test"), json!({})).await;
2656 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2657 let cwd = Path::new("/test");
2658 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2659
2660 // Create agent and connection
2661 let agent = NativeAgent::new(
2662 project.clone(),
2663 thread_store,
2664 templates.clone(),
2665 None,
2666 fake_fs.clone(),
2667 &mut cx.to_async(),
2668 )
2669 .await
2670 .unwrap();
2671 let connection = NativeAgentConnection(agent.clone());
2672
2673 // Create a thread using new_thread
2674 let connection_rc = Rc::new(connection.clone());
2675 let acp_thread = cx
2676 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2677 .await
2678 .expect("new_thread should succeed");
2679
2680 // Get the session_id from the AcpThread
2681 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2682
2683 // Test model_selector returns Some
2684 let selector_opt = connection.model_selector(&session_id);
2685 assert!(
2686 selector_opt.is_some(),
2687 "agent should always support ModelSelector"
2688 );
2689 let selector = selector_opt.unwrap();
2690
2691 // Test list_models
2692 let listed_models = cx
2693 .update(|cx| selector.list_models(cx))
2694 .await
2695 .expect("list_models should succeed");
2696 let AgentModelList::Grouped(listed_models) = listed_models else {
2697 panic!("Unexpected model list type");
2698 };
2699 assert!(!listed_models.is_empty(), "should have at least one model");
2700 assert_eq!(
2701 listed_models[&AgentModelGroupName("Fake".into())][0]
2702 .id
2703 .0
2704 .as_ref(),
2705 "fake/fake"
2706 );
2707
2708 // Test selected_model returns the default
2709 let model = cx
2710 .update(|cx| selector.selected_model(cx))
2711 .await
2712 .expect("selected_model should succeed");
2713 let model = cx
2714 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2715 .unwrap();
2716 let model = model.as_fake();
2717 assert_eq!(model.id().0, "fake", "should return default model");
2718
2719 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2720 cx.run_until_parked();
2721 model.send_last_completion_stream_text_chunk("def");
2722 cx.run_until_parked();
2723 acp_thread.read_with(cx, |thread, cx| {
2724 assert_eq!(
2725 thread.to_markdown(cx),
2726 indoc! {"
2727 ## User
2728
2729 abc
2730
2731 ## Assistant
2732
2733 def
2734
2735 "}
2736 )
2737 });
2738
2739 // Test cancel
2740 cx.update(|cx| connection.cancel(&session_id, cx));
2741 request.await.expect("prompt should fail gracefully");
2742
2743 // Ensure that dropping the ACP thread causes the native thread to be
2744 // dropped as well.
2745 cx.update(|_| drop(acp_thread));
2746 let result = cx
2747 .update(|cx| {
2748 connection.prompt(
2749 Some(acp_thread::UserMessageId::new()),
2750 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2751 cx,
2752 )
2753 })
2754 .await;
2755 assert_eq!(
2756 result.as_ref().unwrap_err().to_string(),
2757 "Session not found",
2758 "unexpected result: {:?}",
2759 result
2760 );
2761}
2762
2763#[gpui::test]
2764async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2765 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2766 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2767 let fake_model = model.as_fake();
2768
2769 let mut events = thread
2770 .update(cx, |thread, cx| {
2771 thread.send(UserMessageId::new(), ["Think"], cx)
2772 })
2773 .unwrap();
2774 cx.run_until_parked();
2775
2776 // Simulate streaming partial input.
2777 let input = json!({});
2778 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2779 LanguageModelToolUse {
2780 id: "1".into(),
2781 name: ThinkingTool::name().into(),
2782 raw_input: input.to_string(),
2783 input,
2784 is_input_complete: false,
2785 thought_signature: None,
2786 },
2787 ));
2788
2789 // Input streaming completed
2790 let input = json!({ "content": "Thinking hard!" });
2791 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2792 LanguageModelToolUse {
2793 id: "1".into(),
2794 name: "thinking".into(),
2795 raw_input: input.to_string(),
2796 input,
2797 is_input_complete: true,
2798 thought_signature: None,
2799 },
2800 ));
2801 fake_model.end_last_completion_stream();
2802 cx.run_until_parked();
2803
2804 let tool_call = expect_tool_call(&mut events).await;
2805 assert_eq!(
2806 tool_call,
2807 acp::ToolCall::new("1", "Thinking")
2808 .kind(acp::ToolKind::Think)
2809 .raw_input(json!({}))
2810 .meta(acp::Meta::from_iter([(
2811 "tool_name".into(),
2812 "thinking".into()
2813 )]))
2814 );
2815 let update = expect_tool_call_update_fields(&mut events).await;
2816 assert_eq!(
2817 update,
2818 acp::ToolCallUpdate::new(
2819 "1",
2820 acp::ToolCallUpdateFields::new()
2821 .title("Thinking")
2822 .kind(acp::ToolKind::Think)
2823 .raw_input(json!({ "content": "Thinking hard!"}))
2824 )
2825 );
2826 let update = expect_tool_call_update_fields(&mut events).await;
2827 assert_eq!(
2828 update,
2829 acp::ToolCallUpdate::new(
2830 "1",
2831 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2832 )
2833 );
2834 let update = expect_tool_call_update_fields(&mut events).await;
2835 assert_eq!(
2836 update,
2837 acp::ToolCallUpdate::new(
2838 "1",
2839 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2840 )
2841 );
2842 let update = expect_tool_call_update_fields(&mut events).await;
2843 assert_eq!(
2844 update,
2845 acp::ToolCallUpdate::new(
2846 "1",
2847 acp::ToolCallUpdateFields::new()
2848 .status(acp::ToolCallStatus::Completed)
2849 .raw_output("Finished thinking.")
2850 )
2851 );
2852}
2853
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // A completion that succeeds on the first attempt must not emit any `Retry` events.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // The retry tests all run in `Burn` completion mode; whether retry behavior depends on
    // the mode is not visible from this file.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Collect retry events until the turn stops (or an error/end of stream is reached).
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2897
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A `ServerOverloaded` error with `retry_after` causes exactly one retry once that delay
    // has elapsed. The partial response is kept and followed by a `[resume]` marker and the
    // retried assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by exactly `retry_after` so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Retry attempts are numbered starting at 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2962
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a completion fails after emitting a tool use, the tool still runs. The retried
    // request then carries both the tool use and its result, so the model can continue from
    // there.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After `retry_after` elapses, the retried request is the latest pending completion.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // `messages[1..]` skips the first request message (presumably the system prompt).
    // Only the final message of the retried request is expected to have `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3043
3044#[gpui::test]
3045async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3046 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3047 let fake_model = model.as_fake();
3048
3049 let mut events = thread
3050 .update(cx, |thread, cx| {
3051 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3052 thread.send(UserMessageId::new(), ["Hello!"], cx)
3053 })
3054 .unwrap();
3055 cx.run_until_parked();
3056
3057 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3058 fake_model.send_last_completion_stream_error(
3059 LanguageModelCompletionError::ServerOverloaded {
3060 provider: LanguageModelProviderName::new("Anthropic"),
3061 retry_after: Some(Duration::from_secs(3)),
3062 },
3063 );
3064 fake_model.end_last_completion_stream();
3065 cx.executor().advance_clock(Duration::from_secs(3));
3066 cx.run_until_parked();
3067 }
3068
3069 let mut errors = Vec::new();
3070 let mut retry_events = Vec::new();
3071 while let Some(event) = events.next().await {
3072 match event {
3073 Ok(ThreadEvent::Retry(retry_status)) => {
3074 retry_events.push(retry_status);
3075 }
3076 Ok(ThreadEvent::Stop(..)) => break,
3077 Err(error) => errors.push(error),
3078 _ => {}
3079 }
3080 }
3081
3082 assert_eq!(
3083 retry_events.len(),
3084 crate::thread::MAX_RETRY_ATTEMPTS as usize
3085 );
3086 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3087 assert_eq!(retry_events[i].attempt, i + 1);
3088 }
3089 assert_eq!(errors.len(), 1);
3090 let error = errors[0]
3091 .downcast_ref::<LanguageModelCompletionError>()
3092 .unwrap();
3093 assert!(matches!(
3094 error,
3095 LanguageModelCompletionError::ServerOverloaded { .. }
3096 ));
3097}
3098
3099/// Filters out the stop events for asserting against in tests
3100fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3101 result_events
3102 .into_iter()
3103 .filter_map(|event| match event.unwrap() {
3104 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3105 _ => None,
3106 })
3107 .collect()
3108}
3109
/// Handles produced by `setup` for a single thread test.
// Field order is kept as-is: it also determines drop order at the end of a test.
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context passed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// The test project's context server store, also backing the thread's
    /// `ContextServerRegistry`.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem holding the settings file and the `/test` project root.
    fs: Arc<FakeFs>,
}
3117
/// Which language model `setup` should wire into the thread.
enum TestModel {
    /// A real model looked up in the `LanguageModelRegistry`. `setup` initializes a
    /// production client with a real HTTP client for it and authenticates the provider.
    Sonnet4,
    /// A `FakeLanguageModel` whose completions are driven manually by the test.
    Fake,
}
3122
3123impl TestModel {
3124 fn id(&self) -> LanguageModelId {
3125 match self {
3126 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3127 TestModel::Fake => unreachable!(),
3128 }
3129 }
3130}
3131
/// Builds a `Thread` backed by the requested `TestModel` inside a fake-filesystem project.
///
/// Steps:
/// 1. Writes a user settings file whose default agent profile (`test-profile`) enables the
///    test tools and `terminal`.
/// 2. Initializes settings and keeps them in sync with that file via `watch_settings`.
///    For `TestModel::Sonnet4` it also sets up a real HTTP client and the language-model
///    providers.
/// 3. Creates a project rooted at `path!("/test")`, resolves the model (for `Sonnet4` it
///    authenticates the provider first), and constructs the thread.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Profile listing the tools that tests in this module may enable on a thread.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real model: needs a real HTTP client and production client for the providers.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is ready immediately; a real model is found in the
    // registry and returned only after its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3235
/// Turns on `env_logger` output for this test binary, but only when the `RUST_LOG`
/// environment variable can be read as a string; otherwise logging stays off.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if !rust_log_requested {
        return;
    }
    env_logger::init();
}
3243
3244fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3245 let fs = fs.clone();
3246 cx.spawn({
3247 async move |cx| {
3248 let mut new_settings_content_rx = settings::watch_config_file(
3249 cx.background_executor(),
3250 fs,
3251 paths::settings_file().clone(),
3252 );
3253
3254 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3255 cx.update(|cx| {
3256 SettingsStore::update_global(cx, |settings, cx| {
3257 settings.set_user_settings(&new_settings_content, cx)
3258 })
3259 })
3260 .ok();
3261 }
3262 }
3263 })
3264 .detach();
3265}
3266
3267fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3268 completion
3269 .tools
3270 .iter()
3271 .map(|tool| tool.name.clone())
3272 .collect()
3273}
3274
/// Registers a fake stdio context server called `name` and starts it on `context_server_store`.
///
/// The server is added to the global `ProjectSettings` as enabled, non-remote, and pointing
/// at a placeholder `somebinary` command. It is backed by a fake transport that:
/// - answers `Initialize` with the latest protocol version and a tools capability
///   (`list_changed: Some(true)`);
/// - answers `ListTools` with a clone of `tools`;
/// - forwards every `CallTool` request to the returned receiver together with a oneshot
///   sender, and resolves the request with whatever the test sends back on that sender.
///
/// The executor is run until parked before returning.
///
/// Returns the receiver of `(CallToolParams, response sender)` pairs for intercepted tool calls.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store treats it as configured.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and wait for its reply. Both unwraps panic if
                // the test side has dropped its end (the receiver or the response sender).
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let start-up work scheduled by `start_server` run before handing control back.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3354
/// Verifies that `tokens_before_message` reports, for each user message, the `input_tokens`
/// from the usage update of the request that preceded it. The first message has `None`, and
/// earlier answers stay stable as later messages are added.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3461
/// Verifies that after `truncate` at a message, `tokens_before_message` returns `None`
/// for the truncated message id. It also checks that the surviving first message still
/// reports `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // NOTE(review): only two messages are actually sent below; comment kept as in the original.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3532
3533#[gpui::test]
3534async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3535 init_test(cx);
3536
3537 let fs = FakeFs::new(cx.executor());
3538 fs.insert_tree("/root", json!({})).await;
3539 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3540
3541 // Test 1: Deny rule blocks command
3542 {
3543 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3544 let environment = Rc::new(FakeThreadEnvironment {
3545 handle: handle.clone(),
3546 });
3547
3548 cx.update(|cx| {
3549 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3550 settings.tool_permissions.tools.insert(
3551 "terminal".into(),
3552 agent_settings::ToolRules {
3553 default_mode: settings::ToolPermissionMode::Confirm,
3554 always_allow: vec![],
3555 always_deny: vec![
3556 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3557 ],
3558 always_confirm: vec![],
3559 invalid_patterns: vec![],
3560 },
3561 );
3562 agent_settings::AgentSettings::override_global(settings, cx);
3563 });
3564
3565 #[allow(clippy::arc_with_non_send_sync)]
3566 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3567 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3568
3569 let task = cx.update(|cx| {
3570 tool.run(
3571 crate::TerminalToolInput {
3572 command: "rm -rf /".to_string(),
3573 cd: ".".to_string(),
3574 timeout_ms: None,
3575 },
3576 event_stream,
3577 cx,
3578 )
3579 });
3580
3581 let result = task.await;
3582 assert!(
3583 result.is_err(),
3584 "expected command to be blocked by deny rule"
3585 );
3586 assert!(
3587 result.unwrap_err().to_string().contains("blocked"),
3588 "error should mention the command was blocked"
3589 );
3590 }
3591
3592 // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3593 {
3594 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3595 let environment = Rc::new(FakeThreadEnvironment {
3596 handle: handle.clone(),
3597 });
3598
3599 cx.update(|cx| {
3600 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3601 settings.always_allow_tool_actions = false;
3602 settings.tool_permissions.tools.insert(
3603 "terminal".into(),
3604 agent_settings::ToolRules {
3605 default_mode: settings::ToolPermissionMode::Deny,
3606 always_allow: vec![
3607 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3608 ],
3609 always_deny: vec![],
3610 always_confirm: vec![],
3611 invalid_patterns: vec![],
3612 },
3613 );
3614 agent_settings::AgentSettings::override_global(settings, cx);
3615 });
3616
3617 #[allow(clippy::arc_with_non_send_sync)]
3618 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3619 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3620
3621 let task = cx.update(|cx| {
3622 tool.run(
3623 crate::TerminalToolInput {
3624 command: "echo hello".to_string(),
3625 cd: ".".to_string(),
3626 timeout_ms: None,
3627 },
3628 event_stream,
3629 cx,
3630 )
3631 });
3632
3633 let update = rx.expect_update_fields().await;
3634 assert!(
3635 update.content.iter().any(|blocks| {
3636 blocks
3637 .iter()
3638 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3639 }),
3640 "expected terminal content (allow rule should skip confirmation and override default deny)"
3641 );
3642
3643 let result = task.await;
3644 assert!(
3645 result.is_ok(),
3646 "expected command to succeed without confirmation"
3647 );
3648 }
3649
3650 // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
3651 {
3652 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3653 let environment = Rc::new(FakeThreadEnvironment {
3654 handle: handle.clone(),
3655 });
3656
3657 cx.update(|cx| {
3658 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3659 settings.always_allow_tool_actions = true;
3660 settings.tool_permissions.tools.insert(
3661 "terminal".into(),
3662 agent_settings::ToolRules {
3663 default_mode: settings::ToolPermissionMode::Allow,
3664 always_allow: vec![],
3665 always_deny: vec![],
3666 always_confirm: vec![
3667 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3668 ],
3669 invalid_patterns: vec![],
3670 },
3671 );
3672 agent_settings::AgentSettings::override_global(settings, cx);
3673 });
3674
3675 #[allow(clippy::arc_with_non_send_sync)]
3676 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3677 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3678
3679 let _task = cx.update(|cx| {
3680 tool.run(
3681 crate::TerminalToolInput {
3682 command: "sudo rm file".to_string(),
3683 cd: ".".to_string(),
3684 timeout_ms: None,
3685 },
3686 event_stream,
3687 cx,
3688 )
3689 });
3690
3691 let auth = rx.expect_authorization().await;
3692 assert!(
3693 auth.tool_call.fields.title.is_some(),
3694 "expected authorization request for sudo command despite always_allow_tool_actions=true"
3695 );
3696 }
3697
3698 // Test 4: default_mode: Deny blocks commands when no pattern matches
3699 {
3700 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3701 let environment = Rc::new(FakeThreadEnvironment {
3702 handle: handle.clone(),
3703 });
3704
3705 cx.update(|cx| {
3706 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3707 settings.always_allow_tool_actions = true;
3708 settings.tool_permissions.tools.insert(
3709 "terminal".into(),
3710 agent_settings::ToolRules {
3711 default_mode: settings::ToolPermissionMode::Deny,
3712 always_allow: vec![],
3713 always_deny: vec![],
3714 always_confirm: vec![],
3715 invalid_patterns: vec![],
3716 },
3717 );
3718 agent_settings::AgentSettings::override_global(settings, cx);
3719 });
3720
3721 #[allow(clippy::arc_with_non_send_sync)]
3722 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3723 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3724
3725 let task = cx.update(|cx| {
3726 tool.run(
3727 crate::TerminalToolInput {
3728 command: "echo hello".to_string(),
3729 cd: ".".to_string(),
3730 timeout_ms: None,
3731 },
3732 event_stream,
3733 cx,
3734 )
3735 });
3736
3737 let result = task.await;
3738 assert!(
3739 result.is_err(),
3740 "expected command to be blocked by default_mode: Deny"
3741 );
3742 assert!(
3743 result.unwrap_err().to_string().contains("disabled"),
3744 "error should mention the tool is disabled"
3745 );
3746 }
3747}
3748
3749#[gpui::test]
3750async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3751 init_test(cx);
3752
3753 cx.update(|cx| {
3754 cx.update_flags(true, vec!["subagents".to_string()]);
3755 });
3756
3757 let fs = FakeFs::new(cx.executor());
3758 fs.insert_tree(path!("/test"), json!({})).await;
3759 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3760 let project_context = cx.new(|_cx| ProjectContext::default());
3761 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3762 let context_server_registry =
3763 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3764 let model = Arc::new(FakeLanguageModel::default());
3765
3766 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3767 let environment = Rc::new(FakeThreadEnvironment { handle });
3768
3769 let thread = cx.new(|cx| {
3770 let mut thread = Thread::new(
3771 project.clone(),
3772 project_context,
3773 context_server_registry,
3774 Templates::new(),
3775 Some(model),
3776 cx,
3777 );
3778 thread.add_default_tools(environment, cx);
3779 thread
3780 });
3781
3782 thread.read_with(cx, |thread, _| {
3783 assert!(
3784 thread.has_registered_tool("subagent"),
3785 "subagent tool should be present when feature flag is enabled"
3786 );
3787 });
3788}
3789
3790#[gpui::test]
3791async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3792 init_test(cx);
3793
3794 cx.update(|cx| {
3795 cx.update_flags(true, vec!["subagents".to_string()]);
3796 });
3797
3798 let fs = FakeFs::new(cx.executor());
3799 fs.insert_tree(path!("/test"), json!({})).await;
3800 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3801 let project_context = cx.new(|_cx| ProjectContext::default());
3802 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3803 let context_server_registry =
3804 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3805 let model = Arc::new(FakeLanguageModel::default());
3806
3807 let subagent_context = SubagentContext {
3808 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3809 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3810 depth: 1,
3811 summary_prompt: "Summarize".to_string(),
3812 context_low_prompt: "Context low".to_string(),
3813 };
3814
3815 let subagent = cx.new(|cx| {
3816 Thread::new_subagent(
3817 project.clone(),
3818 project_context,
3819 context_server_registry,
3820 Templates::new(),
3821 model.clone(),
3822 subagent_context,
3823 std::collections::BTreeMap::new(),
3824 cx,
3825 )
3826 });
3827
3828 subagent.read_with(cx, |thread, _| {
3829 assert!(thread.is_subagent());
3830 assert_eq!(thread.depth(), 1);
3831 assert!(thread.model().is_some());
3832 });
3833}
3834
3835#[gpui::test]
3836async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
3837 init_test(cx);
3838
3839 cx.update(|cx| {
3840 cx.update_flags(true, vec!["subagents".to_string()]);
3841 });
3842
3843 let fs = FakeFs::new(cx.executor());
3844 fs.insert_tree(path!("/test"), json!({})).await;
3845 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3846 let project_context = cx.new(|_cx| ProjectContext::default());
3847 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3848 let context_server_registry =
3849 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3850 let model = Arc::new(FakeLanguageModel::default());
3851
3852 let subagent_context = SubagentContext {
3853 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3854 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3855 depth: MAX_SUBAGENT_DEPTH,
3856 summary_prompt: "Summarize".to_string(),
3857 context_low_prompt: "Context low".to_string(),
3858 };
3859
3860 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3861 let environment = Rc::new(FakeThreadEnvironment { handle });
3862
3863 let deep_subagent = cx.new(|cx| {
3864 let mut thread = Thread::new_subagent(
3865 project.clone(),
3866 project_context,
3867 context_server_registry,
3868 Templates::new(),
3869 model.clone(),
3870 subagent_context,
3871 std::collections::BTreeMap::new(),
3872 cx,
3873 );
3874 thread.add_default_tools(environment, cx);
3875 thread
3876 });
3877
3878 deep_subagent.read_with(cx, |thread, _| {
3879 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3880 assert!(
3881 !thread.has_registered_tool("subagent"),
3882 "subagent tool should not be present at max depth"
3883 );
3884 });
3885}
3886
3887#[gpui::test]
3888async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3889 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3890 let fake_model = model.as_fake();
3891
3892 cx.update(|cx| {
3893 cx.update_flags(true, vec!["subagents".to_string()]);
3894 });
3895
3896 let subagent_context = SubagentContext {
3897 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3898 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3899 depth: 1,
3900 summary_prompt: "Summarize your work".to_string(),
3901 context_low_prompt: "Context low, wrap up".to_string(),
3902 };
3903
3904 let project = thread.read_with(cx, |t, _| t.project.clone());
3905 let project_context = cx.new(|_cx| ProjectContext::default());
3906 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3907 let context_server_registry =
3908 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3909
3910 let subagent = cx.new(|cx| {
3911 Thread::new_subagent(
3912 project.clone(),
3913 project_context,
3914 context_server_registry,
3915 Templates::new(),
3916 model.clone(),
3917 subagent_context,
3918 std::collections::BTreeMap::new(),
3919 cx,
3920 )
3921 });
3922
3923 let task_prompt = "Find all TODO comments in the codebase";
3924 subagent
3925 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
3926 .unwrap();
3927 cx.run_until_parked();
3928
3929 let pending = fake_model.pending_completions();
3930 assert_eq!(pending.len(), 1, "should have one pending completion");
3931
3932 let messages = &pending[0].messages;
3933 let user_messages: Vec<_> = messages
3934 .iter()
3935 .filter(|m| m.role == language_model::Role::User)
3936 .collect();
3937 assert_eq!(user_messages.len(), 1, "should have one user message");
3938
3939 let content = &user_messages[0].content[0];
3940 assert!(
3941 content.to_str().unwrap().contains("TODO"),
3942 "task prompt should be in user message"
3943 );
3944}
3945
/// After a subagent finishes a turn, `request_final_summary` must issue another completion
/// whose last user message carries the configured summary prompt.
#[gpui::test]
async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Please summarize what you found".to_string(),
        context_low_prompt: "Context low, wrap up".to_string(),
    };

    // Reuse the parent thread's project for the subagent.
    let project = thread.read_with(cx, |t, _| t.project.clone());
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    subagent
        .update(cx, |thread, cx| {
            thread.submit_user_message("Do some work", cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Finish the first turn with a plain text answer and an EndTurn stop.
    fake_model.send_last_completion_stream_text_chunk("I did the work");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    subagent
        .update(cx, |thread, cx| thread.request_final_summary(cx))
        .unwrap();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "should have pending completion for summary"
    );

    // The newest pending request is the summary request; its last user message should be
    // the summary prompt configured above.
    let messages = &pending.last().unwrap().messages;
    let user_messages: Vec<_> = messages
        .iter()
        .filter(|m| m.role == language_model::Role::User)
        .collect();

    let last_user = user_messages.last().unwrap();
    assert!(
        last_user.content[0].to_str().unwrap().contains("summarize"),
        "summary prompt should be sent"
    );
}
4018
4019#[gpui::test]
4020async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4021 init_test(cx);
4022
4023 cx.update(|cx| {
4024 cx.update_flags(true, vec!["subagents".to_string()]);
4025 });
4026
4027 let fs = FakeFs::new(cx.executor());
4028 fs.insert_tree(path!("/test"), json!({})).await;
4029 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4030 let project_context = cx.new(|_cx| ProjectContext::default());
4031 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4032 let context_server_registry =
4033 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4034 let model = Arc::new(FakeLanguageModel::default());
4035
4036 let subagent_context = SubagentContext {
4037 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4038 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4039 depth: 1,
4040 summary_prompt: "Summarize".to_string(),
4041 context_low_prompt: "Context low".to_string(),
4042 };
4043
4044 let subagent = cx.new(|cx| {
4045 let mut thread = Thread::new_subagent(
4046 project.clone(),
4047 project_context,
4048 context_server_registry,
4049 Templates::new(),
4050 model.clone(),
4051 subagent_context,
4052 std::collections::BTreeMap::new(),
4053 cx,
4054 );
4055 thread.add_tool(EchoTool);
4056 thread.add_tool(DelayTool);
4057 thread.add_tool(WordListTool);
4058 thread
4059 });
4060
4061 subagent.read_with(cx, |thread, _| {
4062 assert!(thread.has_registered_tool("echo"));
4063 assert!(thread.has_registered_tool("delay"));
4064 assert!(thread.has_registered_tool("word_list"));
4065 });
4066
4067 let allowed: collections::HashSet<gpui::SharedString> =
4068 vec!["echo".into()].into_iter().collect();
4069
4070 subagent.update(cx, |thread, _cx| {
4071 thread.restrict_tools(&allowed);
4072 });
4073
4074 subagent.read_with(cx, |thread, _| {
4075 assert!(
4076 thread.has_registered_tool("echo"),
4077 "echo should still be available"
4078 );
4079 assert!(
4080 !thread.has_registered_tool("delay"),
4081 "delay should be removed"
4082 );
4083 assert!(
4084 !thread.has_registered_tool("word_list"),
4085 "word_list should be removed"
4086 );
4087 });
4088}
4089
/// Cancelling a parent thread must also end the turn of any subagent registered on it via
/// `register_running_subagent`.
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    // The parent only holds a weak handle to the subagent.
    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a turn the fake model never completes, so the subagent stays busy.
    subagent
        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    // NOTE(review): no `run_until_parked` between the cancel and this check, so this relies
    // on the subagent's turn being marked complete synchronously by `cancel`. Confirm that
    // is the intended contract.
    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4163
/// Signalling user cancellation on the tool's event stream must make a running
/// `SubagentTool` finish promptly with a "cancelled by user" error.
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // The spawned subagent gets no inherited tools.
    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
        std::collections::BTreeMap::new();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(
        parent.downgrade(),
        project.clone(),
        project_context,
        context_server_registry,
        Templates::new(),
        0,
        parent_tools,
    ));

    // `cancellation_tx` is the test-side handle used below to trigger user cancellation.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                subagents: vec![crate::SubagentConfig {
                    label: "Long running task".to_string(),
                    task_prompt: "Do a very long task that takes forever".to_string(),
                    summary_prompt: "Summarize".to_string(),
                    context_low_prompt: "Context low".to_string(),
                    timeout_ms: None,
                    allowed_tools: None,
                }],
            },
            event_stream.clone(),
            cx,
        )
    });

    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    // (the 5s timer only guards against a hang; whichever future finishes first wins).
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4252
4253#[gpui::test]
4254async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4255 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4256 let fake_model = model.as_fake();
4257
4258 cx.update(|cx| {
4259 cx.update_flags(true, vec!["subagents".to_string()]);
4260 });
4261
4262 let subagent_context = SubagentContext {
4263 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4264 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4265 depth: 1,
4266 summary_prompt: "Summarize".to_string(),
4267 context_low_prompt: "Context low".to_string(),
4268 };
4269
4270 let project = thread.read_with(cx, |t, _| t.project.clone());
4271 let project_context = cx.new(|_cx| ProjectContext::default());
4272 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4273 let context_server_registry =
4274 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4275
4276 let subagent = cx.new(|cx| {
4277 Thread::new_subagent(
4278 project.clone(),
4279 project_context,
4280 context_server_registry,
4281 Templates::new(),
4282 model.clone(),
4283 subagent_context,
4284 std::collections::BTreeMap::new(),
4285 cx,
4286 )
4287 });
4288
4289 subagent
4290 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4291 .unwrap();
4292 cx.run_until_parked();
4293
4294 subagent.read_with(cx, |thread, _| {
4295 assert!(!thread.is_turn_complete(), "turn should be in progress");
4296 });
4297
4298 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4299 provider: LanguageModelProviderName::from("Fake".to_string()),
4300 });
4301 fake_model.end_last_completion_stream();
4302 cx.run_until_parked();
4303
4304 subagent.read_with(cx, |thread, _| {
4305 assert!(
4306 thread.is_turn_complete(),
4307 "turn should be complete after non-retryable error"
4308 );
4309 });
4310}
4311
4312#[gpui::test]
4313async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4314 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4315 let fake_model = model.as_fake();
4316
4317 cx.update(|cx| {
4318 cx.update_flags(true, vec!["subagents".to_string()]);
4319 });
4320
4321 let subagent_context = SubagentContext {
4322 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4323 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4324 depth: 1,
4325 summary_prompt: "Summarize your work".to_string(),
4326 context_low_prompt: "Context low, stop and summarize".to_string(),
4327 };
4328
4329 let project = thread.read_with(cx, |t, _| t.project.clone());
4330 let project_context = cx.new(|_cx| ProjectContext::default());
4331 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4332 let context_server_registry =
4333 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4334
4335 let subagent = cx.new(|cx| {
4336 Thread::new_subagent(
4337 project.clone(),
4338 project_context.clone(),
4339 context_server_registry.clone(),
4340 Templates::new(),
4341 model.clone(),
4342 subagent_context.clone(),
4343 std::collections::BTreeMap::new(),
4344 cx,
4345 )
4346 });
4347
4348 subagent.update(cx, |thread, _| {
4349 thread.add_tool(EchoTool);
4350 });
4351
4352 subagent
4353 .update(cx, |thread, cx| {
4354 thread.submit_user_message("Do some work", cx)
4355 })
4356 .unwrap();
4357 cx.run_until_parked();
4358
4359 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4360 fake_model
4361 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4362 fake_model.end_last_completion_stream();
4363 cx.run_until_parked();
4364
4365 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4366 assert!(
4367 interrupt_result.is_ok(),
4368 "interrupt_for_summary should succeed"
4369 );
4370
4371 cx.run_until_parked();
4372
4373 let pending = fake_model.pending_completions();
4374 assert!(
4375 !pending.is_empty(),
4376 "should have pending completion for interrupted summary"
4377 );
4378
4379 let messages = &pending.last().unwrap().messages;
4380 let user_messages: Vec<_> = messages
4381 .iter()
4382 .filter(|m| m.role == language_model::Role::User)
4383 .collect();
4384
4385 let last_user = user_messages.last().unwrap();
4386 let content_str = last_user.content[0].to_str().unwrap();
4387 assert!(
4388 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4389 "context_low_prompt should be sent when interrupting: got {:?}",
4390 content_str
4391 );
4392}
4393
4394#[gpui::test]
4395async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4396 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4397 let fake_model = model.as_fake();
4398
4399 cx.update(|cx| {
4400 cx.update_flags(true, vec!["subagents".to_string()]);
4401 });
4402
4403 let subagent_context = SubagentContext {
4404 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4405 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4406 depth: 1,
4407 summary_prompt: "Summarize".to_string(),
4408 context_low_prompt: "Context low".to_string(),
4409 };
4410
4411 let project = thread.read_with(cx, |t, _| t.project.clone());
4412 let project_context = cx.new(|_cx| ProjectContext::default());
4413 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4414 let context_server_registry =
4415 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4416
4417 let subagent = cx.new(|cx| {
4418 Thread::new_subagent(
4419 project.clone(),
4420 project_context,
4421 context_server_registry,
4422 Templates::new(),
4423 model.clone(),
4424 subagent_context,
4425 std::collections::BTreeMap::new(),
4426 cx,
4427 )
4428 });
4429
4430 subagent
4431 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4432 .unwrap();
4433 cx.run_until_parked();
4434
4435 let max_tokens = model.max_token_count();
4436 let high_usage = language_model::TokenUsage {
4437 input_tokens: (max_tokens as f64 * 0.80) as u64,
4438 output_tokens: 0,
4439 cache_creation_input_tokens: 0,
4440 cache_read_input_tokens: 0,
4441 };
4442
4443 fake_model
4444 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4445 fake_model.send_last_completion_stream_text_chunk("Working...");
4446 fake_model
4447 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4448 fake_model.end_last_completion_stream();
4449 cx.run_until_parked();
4450
4451 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4452 assert!(usage.is_some(), "should have token usage after completion");
4453
4454 let usage = usage.unwrap();
4455 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4456 assert!(
4457 remaining_ratio <= 0.25,
4458 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4459 remaining_ratio * 100.0
4460 );
4461}
4462
4463#[gpui::test]
4464async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4465 init_test(cx);
4466
4467 cx.update(|cx| {
4468 cx.update_flags(true, vec!["subagents".to_string()]);
4469 });
4470
4471 let fs = FakeFs::new(cx.executor());
4472 fs.insert_tree(path!("/test"), json!({})).await;
4473 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4474 let project_context = cx.new(|_cx| ProjectContext::default());
4475 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4476 let context_server_registry =
4477 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4478 let model = Arc::new(FakeLanguageModel::default());
4479
4480 let parent = cx.new(|cx| {
4481 let mut thread = Thread::new(
4482 project.clone(),
4483 project_context.clone(),
4484 context_server_registry.clone(),
4485 Templates::new(),
4486 Some(model.clone()),
4487 cx,
4488 );
4489 thread.add_tool(EchoTool);
4490 thread
4491 });
4492
4493 let mut parent_tools: std::collections::BTreeMap<
4494 gpui::SharedString,
4495 Arc<dyn crate::AnyAgentTool>,
4496 > = std::collections::BTreeMap::new();
4497 parent_tools.insert("echo".into(), EchoTool.erase());
4498
4499 #[allow(clippy::arc_with_non_send_sync)]
4500 let tool = Arc::new(SubagentTool::new(
4501 parent.downgrade(),
4502 project,
4503 project_context,
4504 context_server_registry,
4505 Templates::new(),
4506 0,
4507 parent_tools,
4508 ));
4509
4510 let subagent_configs = vec![crate::SubagentConfig {
4511 label: "Test".to_string(),
4512 task_prompt: "Do something".to_string(),
4513 summary_prompt: "Summarize".to_string(),
4514 context_low_prompt: "Context low".to_string(),
4515 timeout_ms: None,
4516 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4517 }];
4518 let result = tool.validate_subagents(&subagent_configs);
4519 assert!(result.is_err(), "should reject unknown tool");
4520 let err_msg = result.unwrap_err().to_string();
4521 assert!(
4522 err_msg.contains("nonexistent_tool"),
4523 "error should mention the invalid tool name: {}",
4524 err_msg
4525 );
4526 assert!(
4527 err_msg.contains("do not exist"),
4528 "error should explain the tool does not exist: {}",
4529 err_msg
4530 );
4531}
4532
4533#[gpui::test]
4534async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4535 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4536 let fake_model = model.as_fake();
4537
4538 cx.update(|cx| {
4539 cx.update_flags(true, vec!["subagents".to_string()]);
4540 });
4541
4542 let subagent_context = SubagentContext {
4543 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4544 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4545 depth: 1,
4546 summary_prompt: "Summarize".to_string(),
4547 context_low_prompt: "Context low".to_string(),
4548 };
4549
4550 let project = thread.read_with(cx, |t, _| t.project.clone());
4551 let project_context = cx.new(|_cx| ProjectContext::default());
4552 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4553 let context_server_registry =
4554 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4555
4556 let subagent = cx.new(|cx| {
4557 Thread::new_subagent(
4558 project.clone(),
4559 project_context,
4560 context_server_registry,
4561 Templates::new(),
4562 model.clone(),
4563 subagent_context,
4564 std::collections::BTreeMap::new(),
4565 cx,
4566 )
4567 });
4568
4569 subagent
4570 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4571 .unwrap();
4572 cx.run_until_parked();
4573
4574 fake_model
4575 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4576 fake_model.end_last_completion_stream();
4577 cx.run_until_parked();
4578
4579 subagent.read_with(cx, |thread, _| {
4580 assert!(
4581 thread.is_turn_complete(),
4582 "turn should complete even with empty response"
4583 );
4584 });
4585}
4586
4587#[gpui::test]
4588async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4589 init_test(cx);
4590
4591 cx.update(|cx| {
4592 cx.update_flags(true, vec!["subagents".to_string()]);
4593 });
4594
4595 let fs = FakeFs::new(cx.executor());
4596 fs.insert_tree(path!("/test"), json!({})).await;
4597 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4598 let project_context = cx.new(|_cx| ProjectContext::default());
4599 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4600 let context_server_registry =
4601 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4602 let model = Arc::new(FakeLanguageModel::default());
4603
4604 let depth_1_context = SubagentContext {
4605 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4606 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4607 depth: 1,
4608 summary_prompt: "Summarize".to_string(),
4609 context_low_prompt: "Context low".to_string(),
4610 };
4611
4612 let depth_1_subagent = cx.new(|cx| {
4613 Thread::new_subagent(
4614 project.clone(),
4615 project_context.clone(),
4616 context_server_registry.clone(),
4617 Templates::new(),
4618 model.clone(),
4619 depth_1_context,
4620 std::collections::BTreeMap::new(),
4621 cx,
4622 )
4623 });
4624
4625 depth_1_subagent.read_with(cx, |thread, _| {
4626 assert_eq!(thread.depth(), 1);
4627 assert!(thread.is_subagent());
4628 });
4629
4630 let depth_2_context = SubagentContext {
4631 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4632 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4633 depth: 2,
4634 summary_prompt: "Summarize depth 2".to_string(),
4635 context_low_prompt: "Context low depth 2".to_string(),
4636 };
4637
4638 let depth_2_subagent = cx.new(|cx| {
4639 Thread::new_subagent(
4640 project.clone(),
4641 project_context.clone(),
4642 context_server_registry.clone(),
4643 Templates::new(),
4644 model.clone(),
4645 depth_2_context,
4646 std::collections::BTreeMap::new(),
4647 cx,
4648 )
4649 });
4650
4651 depth_2_subagent.read_with(cx, |thread, _| {
4652 assert_eq!(thread.depth(), 2);
4653 assert!(thread.is_subagent());
4654 });
4655
4656 depth_2_subagent
4657 .update(cx, |thread, cx| {
4658 thread.submit_user_message("Nested task", cx)
4659 })
4660 .unwrap();
4661 cx.run_until_parked();
4662
4663 let pending = model.as_fake().pending_completions();
4664 assert!(
4665 !pending.is_empty(),
4666 "depth-2 subagent should be able to submit messages"
4667 );
4668}
4669
4670#[gpui::test]
4671async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4672 init_test(cx);
4673 always_allow_tools(cx);
4674
4675 cx.update(|cx| {
4676 cx.update_flags(true, vec!["subagents".to_string()]);
4677 });
4678
4679 let fs = FakeFs::new(cx.executor());
4680 fs.insert_tree(path!("/test"), json!({})).await;
4681 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4682 let project_context = cx.new(|_cx| ProjectContext::default());
4683 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4684 let context_server_registry =
4685 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4686 let model = Arc::new(FakeLanguageModel::default());
4687 let fake_model = model.as_fake();
4688
4689 let subagent_context = SubagentContext {
4690 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4691 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4692 depth: 1,
4693 summary_prompt: "Summarize what you did".to_string(),
4694 context_low_prompt: "Context low".to_string(),
4695 };
4696
4697 let subagent = cx.new(|cx| {
4698 let mut thread = Thread::new_subagent(
4699 project.clone(),
4700 project_context,
4701 context_server_registry,
4702 Templates::new(),
4703 model.clone(),
4704 subagent_context,
4705 std::collections::BTreeMap::new(),
4706 cx,
4707 );
4708 thread.add_tool(EchoTool);
4709 thread
4710 });
4711
4712 subagent.read_with(cx, |thread, _| {
4713 assert!(
4714 thread.has_registered_tool("echo"),
4715 "subagent should have echo tool"
4716 );
4717 });
4718
4719 subagent
4720 .update(cx, |thread, cx| {
4721 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4722 })
4723 .unwrap();
4724 cx.run_until_parked();
4725
4726 let tool_use = LanguageModelToolUse {
4727 id: "tool_call_1".into(),
4728 name: EchoTool::name().into(),
4729 raw_input: json!({"text": "hello world"}).to_string(),
4730 input: json!({"text": "hello world"}),
4731 is_input_complete: true,
4732 thought_signature: None,
4733 };
4734 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4735 fake_model.end_last_completion_stream();
4736 cx.run_until_parked();
4737
4738 let pending = fake_model.pending_completions();
4739 assert!(
4740 !pending.is_empty(),
4741 "should have pending completion after tool use"
4742 );
4743
4744 let last_completion = pending.last().unwrap();
4745 let has_tool_result = last_completion.messages.iter().any(|m| {
4746 m.content
4747 .iter()
4748 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4749 });
4750 assert!(
4751 has_tool_result,
4752 "tool result should be in the messages sent back to the model"
4753 );
4754}
4755
4756#[gpui::test]
4757async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4758 init_test(cx);
4759
4760 cx.update(|cx| {
4761 cx.update_flags(true, vec!["subagents".to_string()]);
4762 });
4763
4764 let fs = FakeFs::new(cx.executor());
4765 fs.insert_tree(path!("/test"), json!({})).await;
4766 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4767 let project_context = cx.new(|_cx| ProjectContext::default());
4768 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4769 let context_server_registry =
4770 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4771 let model = Arc::new(FakeLanguageModel::default());
4772
4773 let parent = cx.new(|cx| {
4774 Thread::new(
4775 project.clone(),
4776 project_context.clone(),
4777 context_server_registry.clone(),
4778 Templates::new(),
4779 Some(model.clone()),
4780 cx,
4781 )
4782 });
4783
4784 let mut subagents = Vec::new();
4785 for i in 0..MAX_PARALLEL_SUBAGENTS {
4786 let subagent_context = SubagentContext {
4787 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4788 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4789 depth: 1,
4790 summary_prompt: "Summarize".to_string(),
4791 context_low_prompt: "Context low".to_string(),
4792 };
4793
4794 let subagent = cx.new(|cx| {
4795 Thread::new_subagent(
4796 project.clone(),
4797 project_context.clone(),
4798 context_server_registry.clone(),
4799 Templates::new(),
4800 model.clone(),
4801 subagent_context,
4802 std::collections::BTreeMap::new(),
4803 cx,
4804 )
4805 });
4806
4807 parent.update(cx, |thread, _cx| {
4808 thread.register_running_subagent(subagent.downgrade());
4809 });
4810 subagents.push(subagent);
4811 }
4812
4813 parent.read_with(cx, |thread, _| {
4814 assert_eq!(
4815 thread.running_subagent_count(),
4816 MAX_PARALLEL_SUBAGENTS,
4817 "should have MAX_PARALLEL_SUBAGENTS registered"
4818 );
4819 });
4820
4821 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4822 std::collections::BTreeMap::new();
4823
4824 #[allow(clippy::arc_with_non_send_sync)]
4825 let tool = Arc::new(SubagentTool::new(
4826 parent.downgrade(),
4827 project.clone(),
4828 project_context,
4829 context_server_registry,
4830 Templates::new(),
4831 0,
4832 parent_tools,
4833 ));
4834
4835 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4836
4837 let result = cx.update(|cx| {
4838 tool.run(
4839 SubagentToolInput {
4840 subagents: vec![crate::SubagentConfig {
4841 label: "Test".to_string(),
4842 task_prompt: "Do something".to_string(),
4843 summary_prompt: "Summarize".to_string(),
4844 context_low_prompt: "Context low".to_string(),
4845 timeout_ms: None,
4846 allowed_tools: None,
4847 }],
4848 },
4849 event_stream,
4850 cx,
4851 )
4852 });
4853
4854 let err = result.await.unwrap_err();
4855 assert!(
4856 err.to_string().contains("Maximum parallel subagents"),
4857 "should reject when max parallel subagents reached: {}",
4858 err
4859 );
4860
4861 drop(subagents);
4862}
4863
4864#[gpui::test]
4865async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4866 init_test(cx);
4867 always_allow_tools(cx);
4868
4869 cx.update(|cx| {
4870 cx.update_flags(true, vec!["subagents".to_string()]);
4871 });
4872
4873 let fs = FakeFs::new(cx.executor());
4874 fs.insert_tree(path!("/test"), json!({})).await;
4875 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4876 let project_context = cx.new(|_cx| ProjectContext::default());
4877 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4878 let context_server_registry =
4879 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4880 let model = Arc::new(FakeLanguageModel::default());
4881 let fake_model = model.as_fake();
4882
4883 let parent = cx.new(|cx| {
4884 let mut thread = Thread::new(
4885 project.clone(),
4886 project_context.clone(),
4887 context_server_registry.clone(),
4888 Templates::new(),
4889 Some(model.clone()),
4890 cx,
4891 );
4892 thread.add_tool(EchoTool);
4893 thread
4894 });
4895
4896 let mut parent_tools: std::collections::BTreeMap<
4897 gpui::SharedString,
4898 Arc<dyn crate::AnyAgentTool>,
4899 > = std::collections::BTreeMap::new();
4900 parent_tools.insert("echo".into(), EchoTool.erase());
4901
4902 #[allow(clippy::arc_with_non_send_sync)]
4903 let tool = Arc::new(SubagentTool::new(
4904 parent.downgrade(),
4905 project.clone(),
4906 project_context,
4907 context_server_registry,
4908 Templates::new(),
4909 0,
4910 parent_tools,
4911 ));
4912
4913 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4914
4915 let task = cx.update(|cx| {
4916 tool.run(
4917 SubagentToolInput {
4918 subagents: vec![crate::SubagentConfig {
4919 label: "Research task".to_string(),
4920 task_prompt: "Find all TODOs in the codebase".to_string(),
4921 summary_prompt: "Summarize what you found".to_string(),
4922 context_low_prompt: "Context low, wrap up".to_string(),
4923 timeout_ms: None,
4924 allowed_tools: None,
4925 }],
4926 },
4927 event_stream,
4928 cx,
4929 )
4930 });
4931
4932 cx.run_until_parked();
4933
4934 let pending = fake_model.pending_completions();
4935 assert!(
4936 !pending.is_empty(),
4937 "subagent should have started and sent a completion request"
4938 );
4939
4940 let first_completion = &pending[0];
4941 let has_task_prompt = first_completion.messages.iter().any(|m| {
4942 m.role == language_model::Role::User
4943 && m.content
4944 .iter()
4945 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
4946 });
4947 assert!(has_task_prompt, "task prompt should be sent to subagent");
4948
4949 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
4950 fake_model
4951 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4952 fake_model.end_last_completion_stream();
4953 cx.run_until_parked();
4954
4955 let pending = fake_model.pending_completions();
4956 assert!(
4957 !pending.is_empty(),
4958 "should have pending completion for summary request"
4959 );
4960
4961 let last_completion = pending.last().unwrap();
4962 let has_summary_prompt = last_completion.messages.iter().any(|m| {
4963 m.role == language_model::Role::User
4964 && m.content.iter().any(|c| {
4965 c.to_str()
4966 .map(|s| s.contains("Summarize") || s.contains("summarize"))
4967 .unwrap_or(false)
4968 })
4969 });
4970 assert!(
4971 has_summary_prompt,
4972 "summary prompt should be sent after task completion"
4973 );
4974
4975 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
4976 fake_model
4977 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4978 fake_model.end_last_completion_stream();
4979 cx.run_until_parked();
4980
4981 let result = task.await;
4982 assert!(result.is_ok(), "subagent tool should complete successfully");
4983
4984 let summary = result.unwrap();
4985 assert!(
4986 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
4987 "summary should contain subagent's response: {}",
4988 summary
4989 );
4990}
4991
4992#[gpui::test]
4993async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4994 init_test(cx);
4995
4996 let fs = FakeFs::new(cx.executor());
4997 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4998 .await;
4999 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5000
5001 cx.update(|cx| {
5002 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5003 settings.tool_permissions.tools.insert(
5004 "edit_file".into(),
5005 agent_settings::ToolRules {
5006 default_mode: settings::ToolPermissionMode::Allow,
5007 always_allow: vec![],
5008 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5009 always_confirm: vec![],
5010 invalid_patterns: vec![],
5011 },
5012 );
5013 agent_settings::AgentSettings::override_global(settings, cx);
5014 });
5015
5016 let context_server_registry =
5017 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5018 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5019 let templates = crate::Templates::new();
5020 let thread = cx.new(|cx| {
5021 crate::Thread::new(
5022 project.clone(),
5023 cx.new(|_cx| prompt_store::ProjectContext::default()),
5024 context_server_registry,
5025 templates.clone(),
5026 None,
5027 cx,
5028 )
5029 });
5030
5031 #[allow(clippy::arc_with_non_send_sync)]
5032 let tool = Arc::new(crate::EditFileTool::new(
5033 project.clone(),
5034 thread.downgrade(),
5035 language_registry,
5036 templates,
5037 ));
5038 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5039
5040 let task = cx.update(|cx| {
5041 tool.run(
5042 crate::EditFileToolInput {
5043 display_description: "Edit sensitive file".to_string(),
5044 path: "root/sensitive_config.txt".into(),
5045 mode: crate::EditFileMode::Edit,
5046 },
5047 event_stream,
5048 cx,
5049 )
5050 });
5051
5052 let result = task.await;
5053 assert!(result.is_err(), "expected edit to be blocked");
5054 assert!(
5055 result.unwrap_err().to_string().contains("blocked"),
5056 "error should mention the edit was blocked"
5057 );
5058}
5059
5060#[gpui::test]
5061async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5062 init_test(cx);
5063
5064 let fs = FakeFs::new(cx.executor());
5065 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5066 .await;
5067 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5068
5069 cx.update(|cx| {
5070 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5071 settings.tool_permissions.tools.insert(
5072 "delete_path".into(),
5073 agent_settings::ToolRules {
5074 default_mode: settings::ToolPermissionMode::Allow,
5075 always_allow: vec![],
5076 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5077 always_confirm: vec![],
5078 invalid_patterns: vec![],
5079 },
5080 );
5081 agent_settings::AgentSettings::override_global(settings, cx);
5082 });
5083
5084 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5085
5086 #[allow(clippy::arc_with_non_send_sync)]
5087 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5088 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5089
5090 let task = cx.update(|cx| {
5091 tool.run(
5092 crate::DeletePathToolInput {
5093 path: "root/important_data.txt".to_string(),
5094 },
5095 event_stream,
5096 cx,
5097 )
5098 });
5099
5100 let result = task.await;
5101 assert!(result.is_err(), "expected deletion to be blocked");
5102 assert!(
5103 result.unwrap_err().to_string().contains("blocked"),
5104 "error should mention the deletion was blocked"
5105 );
5106}
5107
5108#[gpui::test]
5109async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5110 init_test(cx);
5111
5112 let fs = FakeFs::new(cx.executor());
5113 fs.insert_tree(
5114 "/root",
5115 json!({
5116 "safe.txt": "content",
5117 "protected": {}
5118 }),
5119 )
5120 .await;
5121 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5122
5123 cx.update(|cx| {
5124 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5125 settings.tool_permissions.tools.insert(
5126 "move_path".into(),
5127 agent_settings::ToolRules {
5128 default_mode: settings::ToolPermissionMode::Allow,
5129 always_allow: vec![],
5130 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5131 always_confirm: vec![],
5132 invalid_patterns: vec![],
5133 },
5134 );
5135 agent_settings::AgentSettings::override_global(settings, cx);
5136 });
5137
5138 #[allow(clippy::arc_with_non_send_sync)]
5139 let tool = Arc::new(crate::MovePathTool::new(project));
5140 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5141
5142 let task = cx.update(|cx| {
5143 tool.run(
5144 crate::MovePathToolInput {
5145 source_path: "root/safe.txt".to_string(),
5146 destination_path: "root/protected/safe.txt".to_string(),
5147 },
5148 event_stream,
5149 cx,
5150 )
5151 });
5152
5153 let result = task.await;
5154 assert!(
5155 result.is_err(),
5156 "expected move to be blocked due to destination path"
5157 );
5158 assert!(
5159 result.unwrap_err().to_string().contains("blocked"),
5160 "error should mention the move was blocked"
5161 );
5162}
5163
5164#[gpui::test]
5165async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5166 init_test(cx);
5167
5168 let fs = FakeFs::new(cx.executor());
5169 fs.insert_tree(
5170 "/root",
5171 json!({
5172 "secret.txt": "secret content",
5173 "public": {}
5174 }),
5175 )
5176 .await;
5177 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5178
5179 cx.update(|cx| {
5180 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5181 settings.tool_permissions.tools.insert(
5182 "move_path".into(),
5183 agent_settings::ToolRules {
5184 default_mode: settings::ToolPermissionMode::Allow,
5185 always_allow: vec![],
5186 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5187 always_confirm: vec![],
5188 invalid_patterns: vec![],
5189 },
5190 );
5191 agent_settings::AgentSettings::override_global(settings, cx);
5192 });
5193
5194 #[allow(clippy::arc_with_non_send_sync)]
5195 let tool = Arc::new(crate::MovePathTool::new(project));
5196 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5197
5198 let task = cx.update(|cx| {
5199 tool.run(
5200 crate::MovePathToolInput {
5201 source_path: "root/secret.txt".to_string(),
5202 destination_path: "root/public/not_secret.txt".to_string(),
5203 },
5204 event_stream,
5205 cx,
5206 )
5207 });
5208
5209 let result = task.await;
5210 assert!(
5211 result.is_err(),
5212 "expected move to be blocked due to source path"
5213 );
5214 assert!(
5215 result.unwrap_err().to_string().contains("blocked"),
5216 "error should mention the move was blocked"
5217 );
5218}
5219
5220#[gpui::test]
5221async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5222 init_test(cx);
5223
5224 let fs = FakeFs::new(cx.executor());
5225 fs.insert_tree(
5226 "/root",
5227 json!({
5228 "confidential.txt": "confidential data",
5229 "dest": {}
5230 }),
5231 )
5232 .await;
5233 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5234
5235 cx.update(|cx| {
5236 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5237 settings.tool_permissions.tools.insert(
5238 "copy_path".into(),
5239 agent_settings::ToolRules {
5240 default_mode: settings::ToolPermissionMode::Allow,
5241 always_allow: vec![],
5242 always_deny: vec![
5243 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5244 ],
5245 always_confirm: vec![],
5246 invalid_patterns: vec![],
5247 },
5248 );
5249 agent_settings::AgentSettings::override_global(settings, cx);
5250 });
5251
5252 #[allow(clippy::arc_with_non_send_sync)]
5253 let tool = Arc::new(crate::CopyPathTool::new(project));
5254 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5255
5256 let task = cx.update(|cx| {
5257 tool.run(
5258 crate::CopyPathToolInput {
5259 source_path: "root/confidential.txt".to_string(),
5260 destination_path: "root/dest/copy.txt".to_string(),
5261 },
5262 event_stream,
5263 cx,
5264 )
5265 });
5266
5267 let result = task.await;
5268 assert!(result.is_err(), "expected copy to be blocked");
5269 assert!(
5270 result.unwrap_err().to_string().contains("blocked"),
5271 "error should mention the copy was blocked"
5272 );
5273}
5274
5275#[gpui::test]
5276async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5277 init_test(cx);
5278
5279 let fs = FakeFs::new(cx.executor());
5280 fs.insert_tree(
5281 "/root",
5282 json!({
5283 "normal.txt": "normal content",
5284 "readonly": {
5285 "config.txt": "readonly content"
5286 }
5287 }),
5288 )
5289 .await;
5290 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5291
5292 cx.update(|cx| {
5293 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5294 settings.tool_permissions.tools.insert(
5295 "save_file".into(),
5296 agent_settings::ToolRules {
5297 default_mode: settings::ToolPermissionMode::Allow,
5298 always_allow: vec![],
5299 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5300 always_confirm: vec![],
5301 invalid_patterns: vec![],
5302 },
5303 );
5304 agent_settings::AgentSettings::override_global(settings, cx);
5305 });
5306
5307 #[allow(clippy::arc_with_non_send_sync)]
5308 let tool = Arc::new(crate::SaveFileTool::new(project));
5309 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5310
5311 let task = cx.update(|cx| {
5312 tool.run(
5313 crate::SaveFileToolInput {
5314 paths: vec![
5315 std::path::PathBuf::from("root/normal.txt"),
5316 std::path::PathBuf::from("root/readonly/config.txt"),
5317 ],
5318 },
5319 event_stream,
5320 cx,
5321 )
5322 });
5323
5324 let result = task.await;
5325 assert!(
5326 result.is_err(),
5327 "expected save to be blocked due to denied path"
5328 );
5329 assert!(
5330 result.unwrap_err().to_string().contains("blocked"),
5331 "error should mention the save was blocked"
5332 );
5333}
5334
5335#[gpui::test]
5336async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5337 init_test(cx);
5338
5339 let fs = FakeFs::new(cx.executor());
5340 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5341 .await;
5342 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5343
5344 cx.update(|cx| {
5345 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5346 settings.always_allow_tool_actions = false;
5347 settings.tool_permissions.tools.insert(
5348 "save_file".into(),
5349 agent_settings::ToolRules {
5350 default_mode: settings::ToolPermissionMode::Allow,
5351 always_allow: vec![],
5352 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5353 always_confirm: vec![],
5354 invalid_patterns: vec![],
5355 },
5356 );
5357 agent_settings::AgentSettings::override_global(settings, cx);
5358 });
5359
5360 #[allow(clippy::arc_with_non_send_sync)]
5361 let tool = Arc::new(crate::SaveFileTool::new(project));
5362 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5363
5364 let task = cx.update(|cx| {
5365 tool.run(
5366 crate::SaveFileToolInput {
5367 paths: vec![std::path::PathBuf::from("root/config.secret")],
5368 },
5369 event_stream,
5370 cx,
5371 )
5372 });
5373
5374 let result = task.await;
5375 assert!(result.is_err(), "expected save to be blocked");
5376 assert!(
5377 result.unwrap_err().to_string().contains("blocked"),
5378 "error should mention the save was blocked"
5379 );
5380}
5381
5382#[gpui::test]
5383async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5384 init_test(cx);
5385
5386 cx.update(|cx| {
5387 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5388 settings.tool_permissions.tools.insert(
5389 "web_search".into(),
5390 agent_settings::ToolRules {
5391 default_mode: settings::ToolPermissionMode::Allow,
5392 always_allow: vec![],
5393 always_deny: vec![
5394 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5395 ],
5396 always_confirm: vec![],
5397 invalid_patterns: vec![],
5398 },
5399 );
5400 agent_settings::AgentSettings::override_global(settings, cx);
5401 });
5402
5403 #[allow(clippy::arc_with_non_send_sync)]
5404 let tool = Arc::new(crate::WebSearchTool);
5405 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5406
5407 let input: crate::WebSearchToolInput =
5408 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5409
5410 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5411
5412 let result = task.await;
5413 assert!(result.is_err(), "expected search to be blocked");
5414 assert!(
5415 result.unwrap_err().to_string().contains("blocked"),
5416 "error should mention the search was blocked"
5417 );
5418}
5419
5420#[gpui::test]
5421async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5422 init_test(cx);
5423
5424 let fs = FakeFs::new(cx.executor());
5425 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5426 .await;
5427 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5428
5429 cx.update(|cx| {
5430 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5431 settings.always_allow_tool_actions = false;
5432 settings.tool_permissions.tools.insert(
5433 "edit_file".into(),
5434 agent_settings::ToolRules {
5435 default_mode: settings::ToolPermissionMode::Confirm,
5436 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5437 always_deny: vec![],
5438 always_confirm: vec![],
5439 invalid_patterns: vec![],
5440 },
5441 );
5442 agent_settings::AgentSettings::override_global(settings, cx);
5443 });
5444
5445 let context_server_registry =
5446 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5447 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5448 let templates = crate::Templates::new();
5449 let thread = cx.new(|cx| {
5450 crate::Thread::new(
5451 project.clone(),
5452 cx.new(|_cx| prompt_store::ProjectContext::default()),
5453 context_server_registry,
5454 templates.clone(),
5455 None,
5456 cx,
5457 )
5458 });
5459
5460 #[allow(clippy::arc_with_non_send_sync)]
5461 let tool = Arc::new(crate::EditFileTool::new(
5462 project,
5463 thread.downgrade(),
5464 language_registry,
5465 templates,
5466 ));
5467 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5468
5469 let _task = cx.update(|cx| {
5470 tool.run(
5471 crate::EditFileToolInput {
5472 display_description: "Edit README".to_string(),
5473 path: "root/README.md".into(),
5474 mode: crate::EditFileMode::Edit,
5475 },
5476 event_stream,
5477 cx,
5478 )
5479 });
5480
5481 cx.run_until_parked();
5482
5483 let event = rx.try_next();
5484 assert!(
5485 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5486 "expected no authorization request for allowed .md file"
5487 );
5488}
5489
5490#[gpui::test]
5491async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5492 init_test(cx);
5493
5494 cx.update(|cx| {
5495 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5496 settings.tool_permissions.tools.insert(
5497 "fetch".into(),
5498 agent_settings::ToolRules {
5499 default_mode: settings::ToolPermissionMode::Allow,
5500 always_allow: vec![],
5501 always_deny: vec![
5502 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5503 ],
5504 always_confirm: vec![],
5505 invalid_patterns: vec![],
5506 },
5507 );
5508 agent_settings::AgentSettings::override_global(settings, cx);
5509 });
5510
5511 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5512
5513 #[allow(clippy::arc_with_non_send_sync)]
5514 let tool = Arc::new(crate::FetchTool::new(http_client));
5515 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5516
5517 let input: crate::FetchToolInput =
5518 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5519
5520 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5521
5522 let result = task.await;
5523 assert!(result.is_err(), "expected fetch to be blocked");
5524 assert!(
5525 result.unwrap_err().to_string().contains("blocked"),
5526 "error should mention the fetch was blocked"
5527 );
5528}
5529
5530#[gpui::test]
5531async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5532 init_test(cx);
5533
5534 cx.update(|cx| {
5535 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5536 settings.always_allow_tool_actions = false;
5537 settings.tool_permissions.tools.insert(
5538 "fetch".into(),
5539 agent_settings::ToolRules {
5540 default_mode: settings::ToolPermissionMode::Confirm,
5541 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5542 always_deny: vec![],
5543 always_confirm: vec![],
5544 invalid_patterns: vec![],
5545 },
5546 );
5547 agent_settings::AgentSettings::override_global(settings, cx);
5548 });
5549
5550 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5551
5552 #[allow(clippy::arc_with_non_send_sync)]
5553 let tool = Arc::new(crate::FetchTool::new(http_client));
5554 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5555
5556 let input: crate::FetchToolInput =
5557 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5558
5559 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5560
5561 cx.run_until_parked();
5562
5563 let event = rx.try_next();
5564 assert!(
5565 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5566 "expected no authorization request for allowed docs.rs URL"
5567 );
5568}