1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod test_tools;
56use test_tools::*;
57
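/// Installs a test `SettingsStore` as a global so settings-dependent code can run in tests.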
58fn init_test(cx: &mut TestAppContext) {
59 cx.update(|cx| {
60 let settings_store = SettingsStore::test(cx);
61 cx.set_global(settings_store);
62 });
63}
64
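/// Test double for [`crate::TerminalHandle`] that records whether it was killed
/// and lets tests control when the terminal reports an exit.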
65struct FakeTerminalHandle {
66 killed: Arc<AtomicBool>,
67 stopped_by_user: Arc<AtomicBool>,
68 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
69 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
70 output: acp::TerminalOutputResponse,
71 id: acp::TerminalId,
72}
73
74impl FakeTerminalHandle {
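    /// Returns a handle whose terminal never exits on its own; `wait_for_exit`
    /// only resolves once `kill` (or `signal_exit`) is called.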
75 fn new_never_exits(cx: &mut App) -> Self {
76 let killed = Arc::new(AtomicBool::new(false));
77 let stopped_by_user = Arc::new(AtomicBool::new(false));
78
79 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
80
81 let wait_for_exit = cx
82 .spawn(async move |_cx| {
83 // Wait for the exit signal (sent when kill() is called)
84 let _ = exit_receiver.await;
85 acp::TerminalExitStatus::new()
86 })
87 .shared();
88
89 Self {
90 killed,
91 stopped_by_user,
92 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
93 wait_for_exit,
94 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
95 id: acp::TerminalId::new("fake_terminal".to_string()),
96 }
97 }
98
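    /// Returns a handle whose terminal exits immediately with the given exit code.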
99 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
100 let killed = Arc::new(AtomicBool::new(false));
101 let stopped_by_user = Arc::new(AtomicBool::new(false));
102 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
103
104 let wait_for_exit = cx
105 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
106 .shared();
107
108 Self {
109 killed,
110 stopped_by_user,
111 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
112 wait_for_exit,
113 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
114 id: acp::TerminalId::new("fake_terminal".to_string()),
115 }
116 }
117
118 fn was_killed(&self) -> bool {
119 self.killed.load(Ordering::SeqCst)
120 }
121
122 fn set_stopped_by_user(&self, stopped: bool) {
123 self.stopped_by_user.store(stopped, Ordering::SeqCst);
124 }
125
126 fn signal_exit(&self) {
127 if let Some(sender) = self.exit_sender.borrow_mut().take() {
128 let _ = sender.send(());
129 }
130 }
131}
132
133impl crate::TerminalHandle for FakeTerminalHandle {
134 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
135 Ok(self.id.clone())
136 }
137
138 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
139 Ok(self.output.clone())
140 }
141
142 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
143 Ok(self.wait_for_exit.clone())
144 }
145
146 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
147 self.killed.store(true, Ordering::SeqCst);
148 self.signal_exit();
149 Ok(())
150 }
151
152 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
153 Ok(self.stopped_by_user.load(Ordering::SeqCst))
154 }
155}
156
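/// Environment that hands out the same pre-built terminal handle for every command.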
157struct FakeThreadEnvironment {
158 handle: Rc<FakeTerminalHandle>,
159}
160
161impl crate::ThreadEnvironment for FakeThreadEnvironment {
162 fn create_terminal(
163 &self,
164 _command: String,
165 _cwd: Option<std::path::PathBuf>,
166 _output_byte_limit: Option<u64>,
167 _cx: &mut AsyncApp,
168 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
169 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
170 }
171}
172
173/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
174struct MultiTerminalEnvironment {
175 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
176}
177
178impl MultiTerminalEnvironment {
179 fn new() -> Self {
180 Self {
181 handles: std::cell::RefCell::new(Vec::new()),
182 }
183 }
184
185 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
186 self.handles.borrow().clone()
187 }
188}
189
190impl crate::ThreadEnvironment for MultiTerminalEnvironment {
191 fn create_terminal(
192 &self,
193 _command: String,
194 _cwd: Option<std::path::PathBuf>,
195 _output_byte_limit: Option<u64>,
196 cx: &mut AsyncApp,
197 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
198 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
199 self.handles.borrow_mut().push(handle.clone());
200 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
201 }
202}
203
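/// Overrides agent settings so tool calls run without prompting for permission.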
204fn always_allow_tools(cx: &mut TestAppContext) {
205 cx.update(|cx| {
206 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
207 settings.always_allow_tool_actions = true;
208 agent_settings::AgentSettings::override_global(settings, cx);
209 });
210}
211
212#[gpui::test]
213async fn test_echo(cx: &mut TestAppContext) {
214 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
215 let fake_model = model.as_fake();
216
217 let events = thread
218 .update(cx, |thread, cx| {
219 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
220 })
221 .unwrap();
222 cx.run_until_parked();
223 fake_model.send_last_completion_stream_text_chunk("Hello");
224 fake_model
225 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
226 fake_model.end_last_completion_stream();
227
228 let events = events.collect().await;
229 thread.update(cx, |thread, _cx| {
230 assert_eq!(
231 thread.last_message().unwrap().to_markdown(),
232 indoc! {"
233 ## Assistant
234
235 Hello
236 "}
237 )
238 });
239 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
240}
241
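// Passing `timeout_ms` to the terminal tool should kill the underlying terminal once the
// timeout elapses, while still surfacing the partial output it produced.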
242#[gpui::test]
243async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
244 init_test(cx);
245 always_allow_tools(cx);
246
247 let fs = FakeFs::new(cx.executor());
248 let project = Project::test(fs, [], cx).await;
249
250 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
251 let environment = Rc::new(FakeThreadEnvironment {
252 handle: handle.clone(),
253 });
254
255 #[allow(clippy::arc_with_non_send_sync)]
256 let tool = Arc::new(crate::TerminalTool::new(project, environment));
257 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
258
259 let task = cx.update(|cx| {
260 tool.run(
261 crate::TerminalToolInput {
262 command: "sleep 1000".to_string(),
263 cd: ".".to_string(),
264 timeout_ms: Some(5),
265 },
266 event_stream,
267 cx,
268 )
269 });
270
271 let update = rx.expect_update_fields().await;
272 assert!(
273 update.content.iter().any(|blocks| {
274 blocks
275 .iter()
276 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
277 }),
278 "expected tool call update to include terminal content"
279 );
280
281 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
282
283 let deadline = std::time::Instant::now() + Duration::from_millis(500);
284 loop {
285 if let Some(result) = task_future.as_mut().now_or_never() {
286 let result = result.expect("terminal tool task should complete");
287
288 assert!(
289 handle.was_killed(),
290 "expected terminal handle to be killed on timeout"
291 );
292 assert!(
293 result.contains("partial output"),
294 "expected result to include terminal output, got: {result}"
295 );
296 return;
297 }
298
299 if std::time::Instant::now() >= deadline {
300 panic!("timed out waiting for terminal tool task to complete");
301 }
302
303 cx.run_until_parked();
304 cx.background_executor.timer(Duration::from_millis(1)).await;
305 }
306}
307
308#[gpui::test]
309#[ignore]
310async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
311 init_test(cx);
312 always_allow_tools(cx);
313
314 let fs = FakeFs::new(cx.executor());
315 let project = Project::test(fs, [], cx).await;
316
317 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
318 let environment = Rc::new(FakeThreadEnvironment {
319 handle: handle.clone(),
320 });
321
322 #[allow(clippy::arc_with_non_send_sync)]
323 let tool = Arc::new(crate::TerminalTool::new(project, environment));
324 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
325
326 let _task = cx.update(|cx| {
327 tool.run(
328 crate::TerminalToolInput {
329 command: "sleep 1000".to_string(),
330 cd: ".".to_string(),
331 timeout_ms: None,
332 },
333 event_stream,
334 cx,
335 )
336 });
337
338 let update = rx.expect_update_fields().await;
339 assert!(
340 update.content.iter().any(|blocks| {
341 blocks
342 .iter()
343 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
344 }),
345 "expected tool call update to include terminal content"
346 );
347
348 cx.background_executor
349 .timer(Duration::from_millis(25))
350 .await;
351
352 assert!(
353 !handle.was_killed(),
354 "did not expect terminal handle to be killed without a timeout"
355 );
356}
357
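// Thinking chunks should be rendered as a <think> block ahead of the final answer.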
358#[gpui::test]
359async fn test_thinking(cx: &mut TestAppContext) {
360 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
361 let fake_model = model.as_fake();
362
363 let events = thread
364 .update(cx, |thread, cx| {
365 thread.send(
366 UserMessageId::new(),
367 [indoc! {"
368 Testing:
369
370 Generate a thinking step where you just think the word 'Think',
371 and have your final answer be 'Hello'
372 "}],
373 cx,
374 )
375 })
376 .unwrap();
377 cx.run_until_parked();
378 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
379 text: "Think".to_string(),
380 signature: None,
381 });
382 fake_model.send_last_completion_stream_text_chunk("Hello");
383 fake_model
384 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
385 fake_model.end_last_completion_stream();
386
387 let events = events.collect().await;
388 thread.update(cx, |thread, _cx| {
389 assert_eq!(
390 thread.last_message().unwrap().to_markdown(),
391 indoc! {"
392 ## Assistant
393
394 <think>Think</think>
395 Hello
396 "}
397 )
398 });
399 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
400}
401
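// The system prompt should reflect project context (the shell) and include
// tool-related sections when tools are registered.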
402#[gpui::test]
403async fn test_system_prompt(cx: &mut TestAppContext) {
404 let ThreadTest {
405 model,
406 thread,
407 project_context,
408 ..
409 } = setup(cx, TestModel::Fake).await;
410 let fake_model = model.as_fake();
411
412 project_context.update(cx, |project_context, _cx| {
413 project_context.shell = "test-shell".into()
414 });
415 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
416 thread
417 .update(cx, |thread, cx| {
418 thread.send(UserMessageId::new(), ["abc"], cx)
419 })
420 .unwrap();
421 cx.run_until_parked();
422 let mut pending_completions = fake_model.pending_completions();
423 assert_eq!(
424 pending_completions.len(),
425 1,
426 "unexpected pending completions: {:?}",
427 pending_completions
428 );
429
430 let pending_completion = pending_completions.pop().unwrap();
431 assert_eq!(pending_completion.messages[0].role, Role::System);
432
433 let system_message = &pending_completion.messages[0];
434 let system_prompt = system_message.content[0].to_str().unwrap();
435 assert!(
436 system_prompt.contains("test-shell"),
437 "unexpected system message: {:?}",
438 system_message
439 );
440 assert!(
441 system_prompt.contains("## Fixing Diagnostics"),
442 "unexpected system message: {:?}",
443 system_message
444 );
445}
446
447#[gpui::test]
448async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
449 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
450 let fake_model = model.as_fake();
451
452 thread
453 .update(cx, |thread, cx| {
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 let mut pending_completions = fake_model.pending_completions();
459 assert_eq!(
460 pending_completions.len(),
461 1,
462 "unexpected pending completions: {:?}",
463 pending_completions
464 );
465
466 let pending_completion = pending_completions.pop().unwrap();
467 assert_eq!(pending_completion.messages[0].role, Role::System);
468
469 let system_message = &pending_completion.messages[0];
470 let system_prompt = system_message.content[0].to_str().unwrap();
471 assert!(
472 !system_prompt.contains("## Tool Use"),
473 "unexpected system message: {:?}",
474 system_message
475 );
476 assert!(
477 !system_prompt.contains("## Fixing Diagnostics"),
478 "unexpected system message: {:?}",
479 system_message
480 );
481}
482
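// Only the most recent user message (or tool result) should be marked as cacheable.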
483#[gpui::test]
484async fn test_prompt_caching(cx: &mut TestAppContext) {
485 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
486 let fake_model = model.as_fake();
487
488 // Send initial user message and verify it's cached
489 thread
490 .update(cx, |thread, cx| {
491 thread.send(UserMessageId::new(), ["Message 1"], cx)
492 })
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 assert_eq!(
498 completion.messages[1..],
499 vec![LanguageModelRequestMessage {
500 role: Role::User,
501 content: vec!["Message 1".into()],
502 cache: true,
503 reasoning_details: None,
504 }]
505 );
506 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
507 "Response to Message 1".into(),
508 ));
509 fake_model.end_last_completion_stream();
510 cx.run_until_parked();
511
512 // Send another user message and verify only the latest is cached
513 thread
514 .update(cx, |thread, cx| {
515 thread.send(UserMessageId::new(), ["Message 2"], cx)
516 })
517 .unwrap();
518 cx.run_until_parked();
519
520 let completion = fake_model.pending_completions().pop().unwrap();
521 assert_eq!(
522 completion.messages[1..],
523 vec![
524 LanguageModelRequestMessage {
525 role: Role::User,
526 content: vec!["Message 1".into()],
527 cache: false,
528 reasoning_details: None,
529 },
530 LanguageModelRequestMessage {
531 role: Role::Assistant,
532 content: vec!["Response to Message 1".into()],
533 cache: false,
534 reasoning_details: None,
535 },
536 LanguageModelRequestMessage {
537 role: Role::User,
538 content: vec!["Message 2".into()],
539 cache: true,
540 reasoning_details: None,
541 }
542 ]
543 );
544 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
545 "Response to Message 2".into(),
546 ));
547 fake_model.end_last_completion_stream();
548 cx.run_until_parked();
549
550 // Simulate a tool call and verify that the latest tool result is cached
551 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
552 thread
553 .update(cx, |thread, cx| {
554 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
555 })
556 .unwrap();
557 cx.run_until_parked();
558
559 let tool_use = LanguageModelToolUse {
560 id: "tool_1".into(),
561 name: EchoTool::name().into(),
562 raw_input: json!({"text": "test"}).to_string(),
563 input: json!({"text": "test"}),
564 is_input_complete: true,
565 thought_signature: None,
566 };
567 fake_model
568 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
569 fake_model.end_last_completion_stream();
570 cx.run_until_parked();
571
572 let completion = fake_model.pending_completions().pop().unwrap();
573 let tool_result = LanguageModelToolResult {
574 tool_use_id: "tool_1".into(),
575 tool_name: EchoTool::name().into(),
576 is_error: false,
577 content: "test".into(),
578 output: Some("test".into()),
579 };
580 assert_eq!(
581 completion.messages[1..],
582 vec![
583 LanguageModelRequestMessage {
584 role: Role::User,
585 content: vec!["Message 1".into()],
586 cache: false,
587 reasoning_details: None,
588 },
589 LanguageModelRequestMessage {
590 role: Role::Assistant,
591 content: vec!["Response to Message 1".into()],
592 cache: false,
593 reasoning_details: None,
594 },
595 LanguageModelRequestMessage {
596 role: Role::User,
597 content: vec!["Message 2".into()],
598 cache: false,
599 reasoning_details: None,
600 },
601 LanguageModelRequestMessage {
602 role: Role::Assistant,
603 content: vec!["Response to Message 2".into()],
604 cache: false,
605 reasoning_details: None,
606 },
607 LanguageModelRequestMessage {
608 role: Role::User,
609 content: vec!["Use the echo tool".into()],
610 cache: false,
611 reasoning_details: None,
612 },
613 LanguageModelRequestMessage {
614 role: Role::Assistant,
615 content: vec![MessageContent::ToolUse(tool_use)],
616 cache: false,
617 reasoning_details: None,
618 },
619 LanguageModelRequestMessage {
620 role: Role::User,
621 content: vec![MessageContent::ToolResult(tool_result)],
622 cache: true,
623 reasoning_details: None,
624 }
625 ]
626 );
627}
628
629#[gpui::test]
630#[cfg_attr(not(feature = "e2e"), ignore)]
631async fn test_basic_tool_calls(cx: &mut TestAppContext) {
632 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
633
634 // Test a tool call that's likely to complete *before* streaming stops.
635 let events = thread
636 .update(cx, |thread, cx| {
637 thread.add_tool(EchoTool);
638 thread.send(
639 UserMessageId::new(),
640 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
641 cx,
642 )
643 })
644 .unwrap()
645 .collect()
646 .await;
647 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
648
    // Test a tool call that's likely to complete *after* streaming stops.
650 let events = thread
651 .update(cx, |thread, cx| {
652 thread.remove_tool(&EchoTool::name());
653 thread.add_tool(DelayTool);
654 thread.send(
655 UserMessageId::new(),
656 [
657 "Now call the delay tool with 200ms.",
658 "When the timer goes off, then you echo the output of the tool.",
659 ],
660 cx,
661 )
662 })
663 .unwrap()
664 .collect()
665 .await;
666 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
667 thread.update(cx, |thread, _cx| {
668 assert!(
669 thread
670 .last_message()
671 .unwrap()
672 .as_agent_message()
673 .unwrap()
674 .content
675 .iter()
676 .any(|content| {
677 if let AgentMessageContent::Text(text) = content {
678 text.contains("Ding")
679 } else {
680 false
681 }
682 }),
683 "{}",
684 thread.to_markdown()
685 );
686 });
687}
688
689#[gpui::test]
690#[cfg_attr(not(feature = "e2e"), ignore)]
691async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
692 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
693
    // Stream a tool call and verify that partially streamed input is observable before it completes.
695 let mut events = thread
696 .update(cx, |thread, cx| {
697 thread.add_tool(WordListTool);
698 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
699 })
700 .unwrap();
701
702 let mut saw_partial_tool_use = false;
703 while let Some(event) = events.next().await {
704 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
705 thread.update(cx, |thread, _cx| {
706 // Look for a tool use in the thread's last message
707 let message = thread.last_message().unwrap();
708 let agent_message = message.as_agent_message().unwrap();
709 let last_content = agent_message.content.last().unwrap();
710 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
711 assert_eq!(last_tool_use.name.as_ref(), "word_list");
712 if tool_call.status == acp::ToolCallStatus::Pending {
713 if !last_tool_use.is_input_complete
714 && last_tool_use.input.get("g").is_none()
715 {
716 saw_partial_tool_use = true;
717 }
718 } else {
719 last_tool_use
720 .input
721 .get("a")
722 .expect("'a' has streamed because input is now complete");
723 last_tool_use
724 .input
725 .get("g")
726 .expect("'g' has streamed because input is now complete");
727 }
728 } else {
729 panic!("last content should be a tool use");
730 }
731 });
732 }
733 }
734
735 assert!(
736 saw_partial_tool_use,
737 "should see at least one partially streamed tool use in the history"
738 );
739}
740
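// Exercises the permission flow: approving once, denying, and "always allow", and verifies
// the tool results reported back to the model in each case.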
741#[gpui::test]
742async fn test_tool_authorization(cx: &mut TestAppContext) {
743 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
744 let fake_model = model.as_fake();
745
746 let mut events = thread
747 .update(cx, |thread, cx| {
748 thread.add_tool(ToolRequiringPermission);
749 thread.send(UserMessageId::new(), ["abc"], cx)
750 })
751 .unwrap();
752 cx.run_until_parked();
753 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
754 LanguageModelToolUse {
755 id: "tool_id_1".into(),
756 name: ToolRequiringPermission::name().into(),
757 raw_input: "{}".into(),
758 input: json!({}),
759 is_input_complete: true,
760 thought_signature: None,
761 },
762 ));
763 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
764 LanguageModelToolUse {
765 id: "tool_id_2".into(),
766 name: ToolRequiringPermission::name().into(),
767 raw_input: "{}".into(),
768 input: json!({}),
769 is_input_complete: true,
770 thought_signature: None,
771 },
772 ));
773 fake_model.end_last_completion_stream();
774 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
775 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
776
777 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
778 tool_call_auth_1
779 .response
780 .send(acp::PermissionOptionId::new("allow"))
781 .unwrap();
782 cx.run_until_parked();
783
784 // Reject the second - send "deny" option_id directly since Deny is now a button
785 tool_call_auth_2
786 .response
787 .send(acp::PermissionOptionId::new("deny"))
788 .unwrap();
789 cx.run_until_parked();
790
791 let completion = fake_model.pending_completions().pop().unwrap();
792 let message = completion.messages.last().unwrap();
793 assert_eq!(
794 message.content,
795 vec![
796 language_model::MessageContent::ToolResult(LanguageModelToolResult {
797 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
798 tool_name: ToolRequiringPermission::name().into(),
799 is_error: false,
800 content: "Allowed".into(),
801 output: Some("Allowed".into())
802 }),
803 language_model::MessageContent::ToolResult(LanguageModelToolResult {
804 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
805 tool_name: ToolRequiringPermission::name().into(),
806 is_error: true,
807 content: "Permission to run tool denied by user".into(),
808 output: Some("Permission to run tool denied by user".into())
809 })
810 ]
811 );
812
813 // Simulate yet another tool call.
814 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
815 LanguageModelToolUse {
816 id: "tool_id_3".into(),
817 name: ToolRequiringPermission::name().into(),
818 raw_input: "{}".into(),
819 input: json!({}),
820 is_input_complete: true,
821 thought_signature: None,
822 },
823 ));
824 fake_model.end_last_completion_stream();
825
826 // Respond by always allowing tools - send transformed option_id
827 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
828 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
829 tool_call_auth_3
830 .response
831 .send(acp::PermissionOptionId::new(
832 "always_allow:tool_requiring_permission",
833 ))
834 .unwrap();
835 cx.run_until_parked();
836 let completion = fake_model.pending_completions().pop().unwrap();
837 let message = completion.messages.last().unwrap();
838 assert_eq!(
839 message.content,
840 vec![language_model::MessageContent::ToolResult(
841 LanguageModelToolResult {
842 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
843 tool_name: ToolRequiringPermission::name().into(),
844 is_error: false,
845 content: "Allowed".into(),
846 output: Some("Allowed".into())
847 }
848 )]
849 );
850
851 // Simulate a final tool call, ensuring we don't trigger authorization.
852 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
853 LanguageModelToolUse {
854 id: "tool_id_4".into(),
855 name: ToolRequiringPermission::name().into(),
856 raw_input: "{}".into(),
857 input: json!({}),
858 is_input_complete: true,
859 thought_signature: None,
860 },
861 ));
862 fake_model.end_last_completion_stream();
863 cx.run_until_parked();
864 let completion = fake_model.pending_completions().pop().unwrap();
865 let message = completion.messages.last().unwrap();
866 assert_eq!(
867 message.content,
868 vec![language_model::MessageContent::ToolResult(
869 LanguageModelToolResult {
870 tool_use_id: "tool_id_4".into(),
871 tool_name: ToolRequiringPermission::name().into(),
872 is_error: false,
873 content: "Allowed".into(),
874 output: Some("Allowed".into())
875 }
876 )]
877 );
878}
879
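// A tool call referencing a tool that was never registered should be reported as failed.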
880#[gpui::test]
881async fn test_tool_hallucination(cx: &mut TestAppContext) {
882 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
883 let fake_model = model.as_fake();
884
885 let mut events = thread
886 .update(cx, |thread, cx| {
887 thread.send(UserMessageId::new(), ["abc"], cx)
888 })
889 .unwrap();
890 cx.run_until_parked();
891 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
892 LanguageModelToolUse {
893 id: "tool_id_1".into(),
894 name: "nonexistent_tool".into(),
895 raw_input: "{}".into(),
896 input: json!({}),
897 is_input_complete: true,
898 thought_signature: None,
899 },
900 ));
901 fake_model.end_last_completion_stream();
902
903 let tool_call = expect_tool_call(&mut events).await;
904 assert_eq!(tool_call.title, "nonexistent_tool");
905 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
906 let update = expect_tool_call_update_fields(&mut events).await;
907 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
908}
909
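/// Waits for the next event and asserts that it is a `ToolCall`.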
910async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
911 let event = events
912 .next()
913 .await
        .expect("no tool call event received")
915 .unwrap();
916 match event {
917 ThreadEvent::ToolCall(tool_call) => tool_call,
918 event => {
919 panic!("Unexpected event {event:?}");
920 }
921 }
922}
923
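/// Waits for the next event and asserts that it is a `ToolCallUpdate` carrying updated fields.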
924async fn expect_tool_call_update_fields(
925 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
926) -> acp::ToolCallUpdate {
927 let event = events
928 .next()
929 .await
        .expect("no tool call update event received")
931 .unwrap();
932 match event {
933 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
934 event => {
935 panic!("Unexpected event {event:?}");
936 }
937 }
938}
939
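/// Skips events until a `ToolCallAuthorization` arrives, asserting that it offers
/// both "allow always" and "allow once" options.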
940async fn next_tool_call_authorization(
941 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
942) -> ToolCallAuthorization {
943 loop {
944 let event = events
945 .next()
946 .await
947 .expect("no tool call authorization event received")
948 .unwrap();
949 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let allow_always = tool_call_authorization
                .options
                .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
                .map(|option| option.kind);
            let allow_once = tool_call_authorization
                .options
                .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
                .map(|option| option.kind);

            assert_eq!(allow_always, Some(acp::PermissionOptionKind::AllowAlways));
963 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
964 return tool_call_authorization;
965 }
966 }
967}
968
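// The tests below cover the permission options built by `ToolPermissionContext` for
// different tools and inputs, including the option IDs consumed by the UI.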
969#[test]
970fn test_permission_options_terminal_with_pattern() {
971 let permission_options =
972 ToolPermissionContext::new("terminal", "cargo build --release").build_permission_options();
973
974 let PermissionOptions::Dropdown(choices) = permission_options else {
975 panic!("Expected dropdown permission options");
976 };
977
978 assert_eq!(choices.len(), 3);
979 let labels: Vec<&str> = choices
980 .iter()
981 .map(|choice| choice.allow.name.as_ref())
982 .collect();
983 assert!(labels.contains(&"Always for terminal"));
984 assert!(labels.contains(&"Always for `cargo` commands"));
985 assert!(labels.contains(&"Only this time"));
986}
987
988#[test]
989fn test_permission_options_edit_file_with_path_pattern() {
990 let permission_options =
991 ToolPermissionContext::new("edit_file", "src/main.rs").build_permission_options();
992
993 let PermissionOptions::Dropdown(choices) = permission_options else {
994 panic!("Expected dropdown permission options");
995 };
996
997 let labels: Vec<&str> = choices
998 .iter()
999 .map(|choice| choice.allow.name.as_ref())
1000 .collect();
1001 assert!(labels.contains(&"Always for edit file"));
1002 assert!(labels.contains(&"Always for `src/`"));
1003}
1004
1005#[test]
1006fn test_permission_options_fetch_with_domain_pattern() {
1007 let permission_options =
1008 ToolPermissionContext::new("fetch", "https://docs.rs/gpui").build_permission_options();
1009
1010 let PermissionOptions::Dropdown(choices) = permission_options else {
1011 panic!("Expected dropdown permission options");
1012 };
1013
1014 let labels: Vec<&str> = choices
1015 .iter()
1016 .map(|choice| choice.allow.name.as_ref())
1017 .collect();
1018 assert!(labels.contains(&"Always for fetch"));
1019 assert!(labels.contains(&"Always for `docs.rs`"));
1020}
1021
1022#[test]
1023fn test_permission_options_without_pattern() {
1024 let permission_options = ToolPermissionContext::new("terminal", "./deploy.sh --production")
1025 .build_permission_options();
1026
1027 let PermissionOptions::Dropdown(choices) = permission_options else {
1028 panic!("Expected dropdown permission options");
1029 };
1030
1031 assert_eq!(choices.len(), 2);
1032 let labels: Vec<&str> = choices
1033 .iter()
1034 .map(|choice| choice.allow.name.as_ref())
1035 .collect();
1036 assert!(labels.contains(&"Always for terminal"));
1037 assert!(labels.contains(&"Only this time"));
1038 assert!(!labels.iter().any(|label| label.contains("commands")));
1039}
1040
1041#[test]
1042fn test_permission_option_ids_for_terminal() {
1043 let permission_options =
1044 ToolPermissionContext::new("terminal", "cargo build --release").build_permission_options();
1045
1046 let PermissionOptions::Dropdown(choices) = permission_options else {
1047 panic!("Expected dropdown permission options");
1048 };
1049
1050 let allow_ids: Vec<String> = choices
1051 .iter()
1052 .map(|choice| choice.allow.option_id.0.to_string())
1053 .collect();
1054 let deny_ids: Vec<String> = choices
1055 .iter()
1056 .map(|choice| choice.deny.option_id.0.to_string())
1057 .collect();
1058
1059 assert!(allow_ids.contains(&"always_allow:terminal".to_string()));
1060 assert!(allow_ids.contains(&"allow".to_string()));
1061 assert!(
1062 allow_ids
1063 .iter()
1064 .any(|id| id.starts_with("always_allow_pattern:terminal:")),
1065 "Missing allow pattern option"
1066 );
1067
1068 assert!(deny_ids.contains(&"always_deny:terminal".to_string()));
1069 assert!(deny_ids.contains(&"deny".to_string()));
1070 assert!(
1071 deny_ids
1072 .iter()
1073 .any(|id| id.starts_with("always_deny_pattern:terminal:")),
1074 "Missing deny pattern option"
1075 );
1076}
1077
1078#[gpui::test]
1079#[cfg_attr(not(feature = "e2e"), ignore)]
1080async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1081 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1082
1083 // Test concurrent tool calls with different delay times
1084 let events = thread
1085 .update(cx, |thread, cx| {
1086 thread.add_tool(DelayTool);
1087 thread.send(
1088 UserMessageId::new(),
1089 [
1090 "Call the delay tool twice in the same message.",
1091 "Once with 100ms. Once with 300ms.",
1092 "When both timers are complete, describe the outputs.",
1093 ],
1094 cx,
1095 )
1096 })
1097 .unwrap()
1098 .collect()
1099 .await;
1100
1101 let stop_reasons = stop_events(events);
1102 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1103
1104 thread.update(cx, |thread, _cx| {
1105 let last_message = thread.last_message().unwrap();
1106 let agent_message = last_message.as_agent_message().unwrap();
1107 let text = agent_message
1108 .content
1109 .iter()
1110 .filter_map(|content| {
1111 if let AgentMessageContent::Text(text) = content {
1112 Some(text.as_str())
1113 } else {
1114 None
1115 }
1116 })
1117 .collect::<String>();
1118
1119 assert!(text.contains("Ding"));
1120 });
1121}
1122
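// Switching profiles should change which tools are exposed to the model.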
1123#[gpui::test]
1124async fn test_profiles(cx: &mut TestAppContext) {
1125 let ThreadTest {
1126 model, thread, fs, ..
1127 } = setup(cx, TestModel::Fake).await;
1128 let fake_model = model.as_fake();
1129
1130 thread.update(cx, |thread, _cx| {
1131 thread.add_tool(DelayTool);
1132 thread.add_tool(EchoTool);
1133 thread.add_tool(InfiniteTool);
1134 });
1135
1136 // Override profiles and wait for settings to be loaded.
1137 fs.insert_file(
1138 paths::settings_file(),
1139 json!({
1140 "agent": {
1141 "profiles": {
1142 "test-1": {
1143 "name": "Test Profile 1",
1144 "tools": {
1145 EchoTool::name(): true,
1146 DelayTool::name(): true,
1147 }
1148 },
1149 "test-2": {
1150 "name": "Test Profile 2",
1151 "tools": {
1152 InfiniteTool::name(): true,
1153 }
1154 }
1155 }
1156 }
1157 })
1158 .to_string()
1159 .into_bytes(),
1160 )
1161 .await;
1162 cx.run_until_parked();
1163
    // Select the test-1 profile and verify it exposes the echo and delay tools.
1165 thread
1166 .update(cx, |thread, cx| {
1167 thread.set_profile(AgentProfileId("test-1".into()), cx);
1168 thread.send(UserMessageId::new(), ["test"], cx)
1169 })
1170 .unwrap();
1171 cx.run_until_parked();
1172
1173 let mut pending_completions = fake_model.pending_completions();
1174 assert_eq!(pending_completions.len(), 1);
1175 let completion = pending_completions.pop().unwrap();
1176 let tool_names: Vec<String> = completion
1177 .tools
1178 .iter()
1179 .map(|tool| tool.name.clone())
1180 .collect();
1181 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1182 fake_model.end_last_completion_stream();
1183
1184 // Switch to test-2 profile, and verify that it has only the infinite tool.
1185 thread
1186 .update(cx, |thread, cx| {
1187 thread.set_profile(AgentProfileId("test-2".into()), cx);
1188 thread.send(UserMessageId::new(), ["test2"], cx)
1189 })
1190 .unwrap();
1191 cx.run_until_parked();
1192 let mut pending_completions = fake_model.pending_completions();
1193 assert_eq!(pending_completions.len(), 1);
1194 let completion = pending_completions.pop().unwrap();
1195 let tool_names: Vec<String> = completion
1196 .tools
1197 .iter()
1198 .map(|tool| tool.name.clone())
1199 .collect();
1200 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1201}
1202
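// MCP (context server) tools should be exposed to the model, and name collisions with
// native tools should be resolved by prefixing the server name.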
1203#[gpui::test]
1204async fn test_mcp_tools(cx: &mut TestAppContext) {
1205 let ThreadTest {
1206 model,
1207 thread,
1208 context_server_store,
1209 fs,
1210 ..
1211 } = setup(cx, TestModel::Fake).await;
1212 let fake_model = model.as_fake();
1213
1214 // Override profiles and wait for settings to be loaded.
1215 fs.insert_file(
1216 paths::settings_file(),
1217 json!({
1218 "agent": {
1219 "always_allow_tool_actions": true,
1220 "profiles": {
1221 "test": {
1222 "name": "Test Profile",
1223 "enable_all_context_servers": true,
1224 "tools": {
1225 EchoTool::name(): true,
1226 }
1227 },
1228 }
1229 }
1230 })
1231 .to_string()
1232 .into_bytes(),
1233 )
1234 .await;
1235 cx.run_until_parked();
1236 thread.update(cx, |thread, cx| {
1237 thread.set_profile(AgentProfileId("test".into()), cx)
1238 });
1239
1240 let mut mcp_tool_calls = setup_context_server(
1241 "test_server",
1242 vec![context_server::types::Tool {
1243 name: "echo".into(),
1244 description: None,
1245 input_schema: serde_json::to_value(EchoTool::input_schema(
1246 LanguageModelToolSchemaFormat::JsonSchema,
1247 ))
1248 .unwrap(),
1249 output_schema: None,
1250 annotations: None,
1251 }],
1252 &context_server_store,
1253 cx,
1254 );
1255
1256 let events = thread.update(cx, |thread, cx| {
1257 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1258 });
1259 cx.run_until_parked();
1260
1261 // Simulate the model calling the MCP tool.
1262 let completion = fake_model.pending_completions().pop().unwrap();
1263 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1264 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1265 LanguageModelToolUse {
1266 id: "tool_1".into(),
1267 name: "echo".into(),
1268 raw_input: json!({"text": "test"}).to_string(),
1269 input: json!({"text": "test"}),
1270 is_input_complete: true,
1271 thought_signature: None,
1272 },
1273 ));
1274 fake_model.end_last_completion_stream();
1275 cx.run_until_parked();
1276
1277 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1278 assert_eq!(tool_call_params.name, "echo");
1279 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1280 tool_call_response
1281 .send(context_server::types::CallToolResponse {
1282 content: vec![context_server::types::ToolResponseContent::Text {
1283 text: "test".into(),
1284 }],
1285 is_error: None,
1286 meta: None,
1287 structured_content: None,
1288 })
1289 .unwrap();
1290 cx.run_until_parked();
1291
    // The follow-up request should still expose only the MCP echo tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1293 fake_model.send_last_completion_stream_text_chunk("Done!");
1294 fake_model.end_last_completion_stream();
1295 events.collect::<Vec<_>>().await;
1296
1297 // Send again after adding the echo tool, ensuring the name collision is resolved.
1298 let events = thread.update(cx, |thread, cx| {
1299 thread.add_tool(EchoTool);
1300 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1301 });
1302 cx.run_until_parked();
1303 let completion = fake_model.pending_completions().pop().unwrap();
1304 assert_eq!(
1305 tool_names_for_completion(&completion),
1306 vec!["echo", "test_server_echo"]
1307 );
1308 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1309 LanguageModelToolUse {
1310 id: "tool_2".into(),
1311 name: "test_server_echo".into(),
1312 raw_input: json!({"text": "mcp"}).to_string(),
1313 input: json!({"text": "mcp"}),
1314 is_input_complete: true,
1315 thought_signature: None,
1316 },
1317 ));
1318 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1319 LanguageModelToolUse {
1320 id: "tool_3".into(),
1321 name: "echo".into(),
1322 raw_input: json!({"text": "native"}).to_string(),
1323 input: json!({"text": "native"}),
1324 is_input_complete: true,
1325 thought_signature: None,
1326 },
1327 ));
1328 fake_model.end_last_completion_stream();
1329 cx.run_until_parked();
1330
1331 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1332 assert_eq!(tool_call_params.name, "echo");
1333 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1334 tool_call_response
1335 .send(context_server::types::CallToolResponse {
1336 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1337 is_error: None,
1338 meta: None,
1339 structured_content: None,
1340 })
1341 .unwrap();
1342 cx.run_until_parked();
1343
1344 // Ensure the tool results were inserted with the correct names.
1345 let completion = fake_model.pending_completions().pop().unwrap();
1346 assert_eq!(
1347 completion.messages.last().unwrap().content,
1348 vec![
1349 MessageContent::ToolResult(LanguageModelToolResult {
1350 tool_use_id: "tool_3".into(),
1351 tool_name: "echo".into(),
1352 is_error: false,
1353 content: "native".into(),
1354 output: Some("native".into()),
1355 },),
1356 MessageContent::ToolResult(LanguageModelToolResult {
1357 tool_use_id: "tool_2".into(),
1358 tool_name: "test_server_echo".into(),
1359 is_error: false,
1360 content: "mcp".into(),
1361 output: Some("mcp".into()),
1362 },),
1363 ]
1364 );
1365 fake_model.end_last_completion_stream();
1366 events.collect::<Vec<_>>().await;
1367}
1368
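// Conflicting or overlong MCP tool names should be renamed so that every exposed tool
// name is unique and stays within the length limit.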
1369#[gpui::test]
1370async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1371 let ThreadTest {
1372 model,
1373 thread,
1374 context_server_store,
1375 fs,
1376 ..
1377 } = setup(cx, TestModel::Fake).await;
1378 let fake_model = model.as_fake();
1379
1380 // Set up a profile with all tools enabled
1381 fs.insert_file(
1382 paths::settings_file(),
1383 json!({
1384 "agent": {
1385 "profiles": {
1386 "test": {
1387 "name": "Test Profile",
1388 "enable_all_context_servers": true,
1389 "tools": {
1390 EchoTool::name(): true,
1391 DelayTool::name(): true,
1392 WordListTool::name(): true,
1393 ToolRequiringPermission::name(): true,
1394 InfiniteTool::name(): true,
1395 }
1396 },
1397 }
1398 }
1399 })
1400 .to_string()
1401 .into_bytes(),
1402 )
1403 .await;
1404 cx.run_until_parked();
1405
1406 thread.update(cx, |thread, cx| {
1407 thread.set_profile(AgentProfileId("test".into()), cx);
1408 thread.add_tool(EchoTool);
1409 thread.add_tool(DelayTool);
1410 thread.add_tool(WordListTool);
1411 thread.add_tool(ToolRequiringPermission);
1412 thread.add_tool(InfiniteTool);
1413 });
1414
1415 // Set up multiple context servers with some overlapping tool names
1416 let _server1_calls = setup_context_server(
1417 "xxx",
1418 vec![
1419 context_server::types::Tool {
1420 name: "echo".into(), // Conflicts with native EchoTool
1421 description: None,
1422 input_schema: serde_json::to_value(EchoTool::input_schema(
1423 LanguageModelToolSchemaFormat::JsonSchema,
1424 ))
1425 .unwrap(),
1426 output_schema: None,
1427 annotations: None,
1428 },
1429 context_server::types::Tool {
1430 name: "unique_tool_1".into(),
1431 description: None,
1432 input_schema: json!({"type": "object", "properties": {}}),
1433 output_schema: None,
1434 annotations: None,
1435 },
1436 ],
1437 &context_server_store,
1438 cx,
1439 );
1440
1441 let _server2_calls = setup_context_server(
1442 "yyy",
1443 vec![
1444 context_server::types::Tool {
1445 name: "echo".into(), // Also conflicts with native EchoTool
1446 description: None,
1447 input_schema: serde_json::to_value(EchoTool::input_schema(
1448 LanguageModelToolSchemaFormat::JsonSchema,
1449 ))
1450 .unwrap(),
1451 output_schema: None,
1452 annotations: None,
1453 },
1454 context_server::types::Tool {
1455 name: "unique_tool_2".into(),
1456 description: None,
1457 input_schema: json!({"type": "object", "properties": {}}),
1458 output_schema: None,
1459 annotations: None,
1460 },
1461 context_server::types::Tool {
1462 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1463 description: None,
1464 input_schema: json!({"type": "object", "properties": {}}),
1465 output_schema: None,
1466 annotations: None,
1467 },
1468 context_server::types::Tool {
1469 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1470 description: None,
1471 input_schema: json!({"type": "object", "properties": {}}),
1472 output_schema: None,
1473 annotations: None,
1474 },
1475 ],
1476 &context_server_store,
1477 cx,
1478 );
1479 let _server3_calls = setup_context_server(
1480 "zzz",
1481 vec![
1482 context_server::types::Tool {
1483 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1484 description: None,
1485 input_schema: json!({"type": "object", "properties": {}}),
1486 output_schema: None,
1487 annotations: None,
1488 },
1489 context_server::types::Tool {
1490 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1491 description: None,
1492 input_schema: json!({"type": "object", "properties": {}}),
1493 output_schema: None,
1494 annotations: None,
1495 },
1496 context_server::types::Tool {
1497 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1498 description: None,
1499 input_schema: json!({"type": "object", "properties": {}}),
1500 output_schema: None,
1501 annotations: None,
1502 },
1503 ],
1504 &context_server_store,
1505 cx,
1506 );
1507
1508 thread
1509 .update(cx, |thread, cx| {
1510 thread.send(UserMessageId::new(), ["Go"], cx)
1511 })
1512 .unwrap();
1513 cx.run_until_parked();
1514 let completion = fake_model.pending_completions().pop().unwrap();
1515 assert_eq!(
1516 tool_names_for_completion(&completion),
1517 vec![
1518 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1519 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1520 "delay",
1521 "echo",
1522 "infinite",
1523 "tool_requiring_permission",
1524 "unique_tool_1",
1525 "unique_tool_2",
1526 "word_list",
1527 "xxx_echo",
1528 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1529 "yyy_echo",
1530 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1531 ]
1532 );
1533}
1534
1535#[gpui::test]
1536#[cfg_attr(not(feature = "e2e"), ignore)]
1537async fn test_cancellation(cx: &mut TestAppContext) {
1538 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1539
1540 let mut events = thread
1541 .update(cx, |thread, cx| {
1542 thread.add_tool(InfiniteTool);
1543 thread.add_tool(EchoTool);
1544 thread.send(
1545 UserMessageId::new(),
1546 ["Call the echo tool, then call the infinite tool, then explain their output"],
1547 cx,
1548 )
1549 })
1550 .unwrap();
1551
1552 // Wait until both tools are called.
1553 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1554 let mut echo_id = None;
1555 let mut echo_completed = false;
1556 while let Some(event) = events.next().await {
1557 match event.unwrap() {
1558 ThreadEvent::ToolCall(tool_call) => {
1559 assert_eq!(tool_call.title, expected_tools.remove(0));
1560 if tool_call.title == "Echo" {
1561 echo_id = Some(tool_call.tool_call_id);
1562 }
1563 }
1564 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1565 acp::ToolCallUpdate {
1566 tool_call_id,
1567 fields:
1568 acp::ToolCallUpdateFields {
1569 status: Some(acp::ToolCallStatus::Completed),
1570 ..
1571 },
1572 ..
1573 },
1574 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1575 echo_completed = true;
1576 }
1577 _ => {}
1578 }
1579
1580 if expected_tools.is_empty() && echo_completed {
1581 break;
1582 }
1583 }
1584
1585 // Cancel the current send and ensure that the event stream is closed, even
1586 // if one of the tools is still running.
1587 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1588 let events = events.collect::<Vec<_>>().await;
1589 let last_event = events.last();
1590 assert!(
1591 matches!(
1592 last_event,
1593 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1594 ),
1595 "unexpected event {last_event:?}"
1596 );
1597
1598 // Ensure we can still send a new message after cancellation.
1599 let events = thread
1600 .update(cx, |thread, cx| {
1601 thread.send(
1602 UserMessageId::new(),
1603 ["Testing: reply with 'Hello' then stop."],
1604 cx,
1605 )
1606 })
1607 .unwrap()
1608 .collect::<Vec<_>>()
1609 .await;
1610 thread.update(cx, |thread, _cx| {
1611 let message = thread.last_message().unwrap();
1612 let agent_message = message.as_agent_message().unwrap();
1613 assert_eq!(
1614 agent_message.content,
1615 vec![AgentMessageContent::Text("Hello".to_string())]
1616 );
1617 });
1618 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1619}
1620
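// Cancelling a turn while the terminal tool is running should kill the terminal and
// include its partial output in the tool result.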
1621#[gpui::test]
1622async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1623 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1624 always_allow_tools(cx);
1625 let fake_model = model.as_fake();
1626
1627 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1628 let environment = Rc::new(FakeThreadEnvironment {
1629 handle: handle.clone(),
1630 });
1631
1632 let mut events = thread
1633 .update(cx, |thread, cx| {
1634 thread.add_tool(crate::TerminalTool::new(
1635 thread.project().clone(),
1636 environment,
1637 ));
1638 thread.send(UserMessageId::new(), ["run a command"], cx)
1639 })
1640 .unwrap();
1641
1642 cx.run_until_parked();
1643
1644 // Simulate the model calling the terminal tool
1645 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1646 LanguageModelToolUse {
1647 id: "terminal_tool_1".into(),
1648 name: "terminal".into(),
1649 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1650 input: json!({"command": "sleep 1000", "cd": "."}),
1651 is_input_complete: true,
1652 thought_signature: None,
1653 },
1654 ));
1655 fake_model.end_last_completion_stream();
1656
1657 // Wait for the terminal tool to start running
1658 wait_for_terminal_tool_started(&mut events, cx).await;
1659
1660 // Cancel the thread while the terminal is running
1661 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1662
1663 // Collect remaining events, driving the executor to let cancellation complete
1664 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1665
1666 // Verify the terminal was killed
1667 assert!(
1668 handle.was_killed(),
1669 "expected terminal handle to be killed on cancellation"
1670 );
1671
1672 // Verify we got a cancellation stop event
1673 assert_eq!(
1674 stop_events(remaining_events),
1675 vec![acp::StopReason::Cancelled],
1676 );
1677
1678 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1679 thread.update(cx, |thread, _cx| {
1680 let message = thread.last_message().unwrap();
1681 let agent_message = message.as_agent_message().unwrap();
1682
1683 let tool_use = agent_message
1684 .content
1685 .iter()
1686 .find_map(|content| match content {
1687 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1688 _ => None,
1689 })
1690 .expect("expected tool use in agent message");
1691
1692 let tool_result = agent_message
1693 .tool_results
1694 .get(&tool_use.id)
1695 .expect("expected tool result");
1696
1697 let result_text = match &tool_result.content {
1698 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1699 _ => panic!("expected text content in tool result"),
1700 };
1701
1702 // "partial output" comes from FakeTerminalHandle's output field
1703 assert!(
1704 result_text.contains("partial output"),
1705 "expected tool result to contain terminal output, got: {result_text}"
1706 );
1707 // Match the actual format from process_content in terminal_tool.rs
1708 assert!(
1709 result_text.contains("The user stopped this command"),
1710 "expected tool result to indicate user stopped, got: {result_text}"
1711 );
1712 });
1713
1714 // Verify we can send a new message after cancellation
1715 verify_thread_recovery(&thread, &fake_model, cx).await;
1716}
1717
1718#[gpui::test]
1719async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1720 // This test verifies that tools which properly handle cancellation via
1721 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1722 // to cancellation and report that they were cancelled.
1723 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1724 always_allow_tools(cx);
1725 let fake_model = model.as_fake();
1726
1727 let (tool, was_cancelled) = CancellationAwareTool::new();
1728
1729 let mut events = thread
1730 .update(cx, |thread, cx| {
1731 thread.add_tool(tool);
1732 thread.send(
1733 UserMessageId::new(),
1734 ["call the cancellation aware tool"],
1735 cx,
1736 )
1737 })
1738 .unwrap();
1739
1740 cx.run_until_parked();
1741
1742 // Simulate the model calling the cancellation-aware tool
1743 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1744 LanguageModelToolUse {
1745 id: "cancellation_aware_1".into(),
1746 name: "cancellation_aware".into(),
1747 raw_input: r#"{}"#.into(),
1748 input: json!({}),
1749 is_input_complete: true,
1750 thought_signature: None,
1751 },
1752 ));
1753 fake_model.end_last_completion_stream();
1754
1755 cx.run_until_parked();
1756
1757 // Wait for the tool call to be reported
1758 let mut tool_started = false;
1759 let deadline = cx.executor().num_cpus() * 100;
1760 for _ in 0..deadline {
1761 cx.run_until_parked();
1762
1763 while let Some(Some(event)) = events.next().now_or_never() {
1764 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1765 if tool_call.title == "Cancellation Aware Tool" {
1766 tool_started = true;
1767 break;
1768 }
1769 }
1770 }
1771
1772 if tool_started {
1773 break;
1774 }
1775
1776 cx.background_executor
1777 .timer(Duration::from_millis(10))
1778 .await;
1779 }
1780 assert!(tool_started, "expected cancellation aware tool to start");
1781
1782 // Cancel the thread and wait for it to complete
1783 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1784
1785 // The cancel task should complete promptly because the tool handles cancellation
1786 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1787 futures::select! {
1788 _ = cancel_task.fuse() => {}
1789 _ = timeout.fuse() => {
1790 panic!("cancel task timed out - tool did not respond to cancellation");
1791 }
1792 }
1793
1794 // Verify the tool detected cancellation via its flag
1795 assert!(
1796 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1797 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1798 );
1799
1800 // Collect remaining events
1801 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1802
1803 // Verify we got a cancellation stop event
1804 assert_eq!(
1805 stop_events(remaining_events),
1806 vec![acp::StopReason::Cancelled],
1807 );
1808
1809 // Verify we can send a new message after cancellation
1810 verify_thread_recovery(&thread, &fake_model, cx).await;
1811}
1812
1813/// Helper to verify thread can recover after cancellation by sending a simple message.
1814async fn verify_thread_recovery(
1815 thread: &Entity<Thread>,
1816 fake_model: &FakeLanguageModel,
1817 cx: &mut TestAppContext,
1818) {
1819 let events = thread
1820 .update(cx, |thread, cx| {
1821 thread.send(
1822 UserMessageId::new(),
1823 ["Testing: reply with 'Hello' then stop."],
1824 cx,
1825 )
1826 })
1827 .unwrap();
1828 cx.run_until_parked();
1829 fake_model.send_last_completion_stream_text_chunk("Hello");
1830 fake_model
1831 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1832 fake_model.end_last_completion_stream();
1833
1834 let events = events.collect::<Vec<_>>().await;
1835 thread.update(cx, |thread, _cx| {
1836 let message = thread.last_message().unwrap();
1837 let agent_message = message.as_agent_message().unwrap();
1838 assert_eq!(
1839 agent_message.content,
1840 vec![AgentMessageContent::Text("Hello".to_string())]
1841 );
1842 });
1843 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1844}
1845
1846/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1847async fn wait_for_terminal_tool_started(
1848 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1849 cx: &mut TestAppContext,
1850) {
1851 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1852 for _ in 0..deadline {
1853 cx.run_until_parked();
1854
1855 while let Some(Some(event)) = events.next().now_or_never() {
1856 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1857 update,
1858 ))) = &event
1859 {
1860 if update.fields.content.as_ref().is_some_and(|content| {
1861 content
1862 .iter()
1863 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1864 }) {
1865 return;
1866 }
1867 }
1868 }
1869
1870 cx.background_executor
1871 .timer(Duration::from_millis(10))
1872 .await;
1873 }
1874 panic!("terminal tool did not start within the expected time");
1875}
1876
1877/// Collects events until a Stop event is received, driving the executor to completion.
1878async fn collect_events_until_stop(
1879 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1880 cx: &mut TestAppContext,
1881) -> Vec<Result<ThreadEvent>> {
1882 let mut collected = Vec::new();
1883 let deadline = cx.executor().num_cpus() * 200;
1884
1885 for _ in 0..deadline {
1886 cx.executor().advance_clock(Duration::from_millis(10));
1887 cx.run_until_parked();
1888
1889 while let Some(Some(event)) = events.next().now_or_never() {
1890 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1891 collected.push(event);
1892 if is_stop {
1893 return collected;
1894 }
1895 }
1896 }
1897 panic!(
1898 "did not receive Stop event within the expected time; collected {} events",
1899 collected.len()
1900 );
1901}
1902
1903#[gpui::test]
1904async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
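    // Tests that truncating the thread while a terminal tool is running kills the terminal,
    // leaves the thread empty, and keeps it usable for new messages.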
1905 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1906 always_allow_tools(cx);
1907 let fake_model = model.as_fake();
1908
1909 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1910 let environment = Rc::new(FakeThreadEnvironment {
1911 handle: handle.clone(),
1912 });
1913
1914 let message_id = UserMessageId::new();
1915 let mut events = thread
1916 .update(cx, |thread, cx| {
1917 thread.add_tool(crate::TerminalTool::new(
1918 thread.project().clone(),
1919 environment,
1920 ));
1921 thread.send(message_id.clone(), ["run a command"], cx)
1922 })
1923 .unwrap();
1924
1925 cx.run_until_parked();
1926
1927 // Simulate the model calling the terminal tool
1928 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1929 LanguageModelToolUse {
1930 id: "terminal_tool_1".into(),
1931 name: "terminal".into(),
1932 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1933 input: json!({"command": "sleep 1000", "cd": "."}),
1934 is_input_complete: true,
1935 thought_signature: None,
1936 },
1937 ));
1938 fake_model.end_last_completion_stream();
1939
1940 // Wait for the terminal tool to start running
1941 wait_for_terminal_tool_started(&mut events, cx).await;
1942
1943 // Truncate the thread while the terminal is running
1944 thread
1945 .update(cx, |thread, cx| thread.truncate(message_id, cx))
1946 .unwrap();
1947
1948 // Drive the executor to let cancellation complete
1949 let _ = collect_events_until_stop(&mut events, cx).await;
1950
1951 // Verify the terminal was killed
1952 assert!(
1953 handle.was_killed(),
1954 "expected terminal handle to be killed on truncate"
1955 );
1956
1957 // Verify the thread is empty after truncation
1958 thread.update(cx, |thread, _cx| {
1959 assert_eq!(
1960 thread.to_markdown(),
1961 "",
1962 "expected thread to be empty after truncating the only message"
1963 );
1964 });
1965
1966 // Verify we can send a new message after truncation
1967 verify_thread_recovery(&thread, &fake_model, cx).await;
1968}
1969
1970#[gpui::test]
1971async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1972 // Tests that cancellation properly kills all running terminal tools when multiple are active.
1973 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1974 always_allow_tools(cx);
1975 let fake_model = model.as_fake();
1976
1977 let environment = Rc::new(MultiTerminalEnvironment::new());
1978
1979 let mut events = thread
1980 .update(cx, |thread, cx| {
1981 thread.add_tool(crate::TerminalTool::new(
1982 thread.project().clone(),
1983 environment.clone(),
1984 ));
1985 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1986 })
1987 .unwrap();
1988
1989 cx.run_until_parked();
1990
1991 // Simulate the model calling two terminal tools
1992 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1993 LanguageModelToolUse {
1994 id: "terminal_tool_1".into(),
1995 name: "terminal".into(),
1996 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1997 input: json!({"command": "sleep 1000", "cd": "."}),
1998 is_input_complete: true,
1999 thought_signature: None,
2000 },
2001 ));
2002 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2003 LanguageModelToolUse {
2004 id: "terminal_tool_2".into(),
2005 name: "terminal".into(),
2006 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2007 input: json!({"command": "sleep 2000", "cd": "."}),
2008 is_input_complete: true,
2009 thought_signature: None,
2010 },
2011 ));
2012 fake_model.end_last_completion_stream();
2013
2014 // Wait for both terminal tools to start by counting terminal content updates
2015 let mut terminals_started = 0;
2016 let deadline = cx.executor().num_cpus() * 100;
2017 for _ in 0..deadline {
2018 cx.run_until_parked();
2019
2020 while let Some(Some(event)) = events.next().now_or_never() {
2021 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2022 update,
2023 ))) = &event
2024 {
2025 if update.fields.content.as_ref().is_some_and(|content| {
2026 content
2027 .iter()
2028 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2029 }) {
2030 terminals_started += 1;
2031 if terminals_started >= 2 {
2032 break;
2033 }
2034 }
2035 }
2036 }
2037 if terminals_started >= 2 {
2038 break;
2039 }
2040
2041 cx.background_executor
2042 .timer(Duration::from_millis(10))
2043 .await;
2044 }
2045 assert!(
2046 terminals_started >= 2,
2047 "expected 2 terminal tools to start, got {terminals_started}"
2048 );
2049
2050 // Cancel the thread while both terminals are running
2051 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2052
2053 // Collect remaining events
2054 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2055
2056 // Verify both terminal handles were killed
2057 let handles = environment.handles();
2058 assert_eq!(
2059 handles.len(),
2060 2,
2061 "expected 2 terminal handles to be created"
2062 );
2063 assert!(
2064 handles[0].was_killed(),
2065 "expected first terminal handle to be killed on cancellation"
2066 );
2067 assert!(
2068 handles[1].was_killed(),
2069 "expected second terminal handle to be killed on cancellation"
2070 );
2071
2072 // Verify we got a cancellation stop event
2073 assert_eq!(
2074 stop_events(remaining_events),
2075 vec![acp::StopReason::Cancelled],
2076 );
2077}
2078
2079#[gpui::test]
2080async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2081 // Tests that clicking the stop button on the terminal card (as opposed to the main
2082 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2083 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2084 always_allow_tools(cx);
2085 let fake_model = model.as_fake();
2086
2087 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2088 let environment = Rc::new(FakeThreadEnvironment {
2089 handle: handle.clone(),
2090 });
2091
2092 let mut events = thread
2093 .update(cx, |thread, cx| {
2094 thread.add_tool(crate::TerminalTool::new(
2095 thread.project().clone(),
2096 environment,
2097 ));
2098 thread.send(UserMessageId::new(), ["run a command"], cx)
2099 })
2100 .unwrap();
2101
2102 cx.run_until_parked();
2103
2104 // Simulate the model calling the terminal tool
2105 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2106 LanguageModelToolUse {
2107 id: "terminal_tool_1".into(),
2108 name: "terminal".into(),
2109 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2110 input: json!({"command": "sleep 1000", "cd": "."}),
2111 is_input_complete: true,
2112 thought_signature: None,
2113 },
2114 ));
2115 fake_model.end_last_completion_stream();
2116
2117 // Wait for the terminal tool to start running
2118 wait_for_terminal_tool_started(&mut events, cx).await;
2119
2120 // Simulate user clicking stop on the terminal card itself.
2121 // This sets the flag and signals exit (simulating what the real UI would do).
2122 handle.set_stopped_by_user(true);
2123 handle.killed.store(true, Ordering::SeqCst);
2124 handle.signal_exit();
2125
2126 // Wait for the tool to complete
2127 cx.run_until_parked();
2128
    // The thread continues after the tool completes, so simulate the model ending its turn
2130 fake_model
2131 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2132 fake_model.end_last_completion_stream();
2133
2134 // Collect remaining events
2135 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2136
2137 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2138 assert_eq!(
2139 stop_events(remaining_events),
2140 vec![acp::StopReason::EndTurn],
2141 );
2142
2143 // Verify the tool result indicates user stopped
2144 thread.update(cx, |thread, _cx| {
2145 let message = thread.last_message().unwrap();
2146 let agent_message = message.as_agent_message().unwrap();
2147
2148 let tool_use = agent_message
2149 .content
2150 .iter()
2151 .find_map(|content| match content {
2152 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2153 _ => None,
2154 })
2155 .expect("expected tool use in agent message");
2156
2157 let tool_result = agent_message
2158 .tool_results
2159 .get(&tool_use.id)
2160 .expect("expected tool result");
2161
2162 let result_text = match &tool_result.content {
2163 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2164 _ => panic!("expected text content in tool result"),
2165 };
2166
2167 assert!(
2168 result_text.contains("The user stopped this command"),
2169 "expected tool result to indicate user stopped, got: {result_text}"
2170 );
2171 });
2172}
2173
2174#[gpui::test]
2175async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2176 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2177 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2178 always_allow_tools(cx);
2179 let fake_model = model.as_fake();
2180
2181 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2182 let environment = Rc::new(FakeThreadEnvironment {
2183 handle: handle.clone(),
2184 });
2185
2186 let mut events = thread
2187 .update(cx, |thread, cx| {
2188 thread.add_tool(crate::TerminalTool::new(
2189 thread.project().clone(),
2190 environment,
2191 ));
2192 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2193 })
2194 .unwrap();
2195
2196 cx.run_until_parked();
2197
2198 // Simulate the model calling the terminal tool with a short timeout
2199 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2200 LanguageModelToolUse {
2201 id: "terminal_tool_1".into(),
2202 name: "terminal".into(),
2203 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2204 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2205 is_input_complete: true,
2206 thought_signature: None,
2207 },
2208 ));
2209 fake_model.end_last_completion_stream();
2210
2211 // Wait for the terminal tool to start running
2212 wait_for_terminal_tool_started(&mut events, cx).await;
2213
2214 // Advance clock past the timeout
2215 cx.executor().advance_clock(Duration::from_millis(200));
2216 cx.run_until_parked();
2217
    // The thread continues after the tool completes, so simulate the model ending its turn
2219 fake_model
2220 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2221 fake_model.end_last_completion_stream();
2222
2223 // Collect remaining events
2224 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2225
2226 // Verify the terminal was killed due to timeout
2227 assert!(
2228 handle.was_killed(),
2229 "expected terminal handle to be killed on timeout"
2230 );
2231
2232 // Verify we got an EndTurn (the tool completed, just with timeout)
2233 assert_eq!(
2234 stop_events(remaining_events),
2235 vec![acp::StopReason::EndTurn],
2236 );
2237
2238 // Verify the tool result indicates timeout, not user stopped
2239 thread.update(cx, |thread, _cx| {
2240 let message = thread.last_message().unwrap();
2241 let agent_message = message.as_agent_message().unwrap();
2242
2243 let tool_use = agent_message
2244 .content
2245 .iter()
2246 .find_map(|content| match content {
2247 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2248 _ => None,
2249 })
2250 .expect("expected tool use in agent message");
2251
2252 let tool_result = agent_message
2253 .tool_results
2254 .get(&tool_use.id)
2255 .expect("expected tool result");
2256
2257 let result_text = match &tool_result.content {
2258 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2259 _ => panic!("expected text content in tool result"),
2260 };
2261
2262 assert!(
2263 result_text.contains("timed out"),
2264 "expected tool result to indicate timeout, got: {result_text}"
2265 );
2266 assert!(
2267 !result_text.contains("The user stopped"),
2268 "tool result should not mention user stopped when it timed out, got: {result_text}"
2269 );
2270 });
2271}
2272
2273#[gpui::test]
2274async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
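    // Tests that sending a new message while a previous send is still streaming cancels the
    // in-flight send.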
2275 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2276 let fake_model = model.as_fake();
2277
2278 let events_1 = thread
2279 .update(cx, |thread, cx| {
2280 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2281 })
2282 .unwrap();
2283 cx.run_until_parked();
2284 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2285 cx.run_until_parked();
2286
2287 let events_2 = thread
2288 .update(cx, |thread, cx| {
2289 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2290 })
2291 .unwrap();
2292 cx.run_until_parked();
2293 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2294 fake_model
2295 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2296 fake_model.end_last_completion_stream();
2297
2298 let events_1 = events_1.collect::<Vec<_>>().await;
2299 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2300 let events_2 = events_2.collect::<Vec<_>>().await;
2301 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2302}
2303
2304#[gpui::test]
2305async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
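    // Tests that a send started after the previous one has finished does not cancel anything.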
2306 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2307 let fake_model = model.as_fake();
2308
2309 let events_1 = thread
2310 .update(cx, |thread, cx| {
2311 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2312 })
2313 .unwrap();
2314 cx.run_until_parked();
2315 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2316 fake_model
2317 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2318 fake_model.end_last_completion_stream();
2319 let events_1 = events_1.collect::<Vec<_>>().await;
2320
2321 let events_2 = thread
2322 .update(cx, |thread, cx| {
2323 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2324 })
2325 .unwrap();
2326 cx.run_until_parked();
2327 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2328 fake_model
2329 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2330 fake_model.end_last_completion_stream();
2331 let events_2 = events_2.collect::<Vec<_>>().await;
2332
2333 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2334 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2335}
2336
2337#[gpui::test]
2338async fn test_refusal(cx: &mut TestAppContext) {
2339 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2340 let fake_model = model.as_fake();
2341
2342 let events = thread
2343 .update(cx, |thread, cx| {
2344 thread.send(UserMessageId::new(), ["Hello"], cx)
2345 })
2346 .unwrap();
2347 cx.run_until_parked();
2348 thread.read_with(cx, |thread, _| {
2349 assert_eq!(
2350 thread.to_markdown(),
2351 indoc! {"
2352 ## User
2353
2354 Hello
2355 "}
2356 );
2357 });
2358
2359 fake_model.send_last_completion_stream_text_chunk("Hey!");
2360 cx.run_until_parked();
2361 thread.read_with(cx, |thread, _| {
2362 assert_eq!(
2363 thread.to_markdown(),
2364 indoc! {"
2365 ## User
2366
2367 Hello
2368
2369 ## Assistant
2370
2371 Hey!
2372 "}
2373 );
2374 });
2375
    // If the model refuses to continue, the thread should remove the refused user message and everything after it.
2377 fake_model
2378 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2379 let events = events.collect::<Vec<_>>().await;
2380 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2381 thread.read_with(cx, |thread, _| {
2382 assert_eq!(thread.to_markdown(), "");
2383 });
2384}
2385
2386#[gpui::test]
2387async fn test_truncate_first_message(cx: &mut TestAppContext) {
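    // Tests truncating at the first (and only) message: the thread and token usage reset, and a
    // new conversation can be started afterwards.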
2388 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2389 let fake_model = model.as_fake();
2390
2391 let message_id = UserMessageId::new();
2392 thread
2393 .update(cx, |thread, cx| {
2394 thread.send(message_id.clone(), ["Hello"], cx)
2395 })
2396 .unwrap();
2397 cx.run_until_parked();
2398 thread.read_with(cx, |thread, _| {
2399 assert_eq!(
2400 thread.to_markdown(),
2401 indoc! {"
2402 ## User
2403
2404 Hello
2405 "}
2406 );
2407 assert_eq!(thread.latest_token_usage(), None);
2408 });
2409
2410 fake_model.send_last_completion_stream_text_chunk("Hey!");
2411 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2412 language_model::TokenUsage {
2413 input_tokens: 32_000,
2414 output_tokens: 16_000,
2415 cache_creation_input_tokens: 0,
2416 cache_read_input_tokens: 0,
2417 },
2418 ));
2419 cx.run_until_parked();
2420 thread.read_with(cx, |thread, _| {
2421 assert_eq!(
2422 thread.to_markdown(),
2423 indoc! {"
2424 ## User
2425
2426 Hello
2427
2428 ## Assistant
2429
2430 Hey!
2431 "}
2432 );
2433 assert_eq!(
2434 thread.latest_token_usage(),
2435 Some(acp_thread::TokenUsage {
2436 used_tokens: 32_000 + 16_000,
2437 max_tokens: 1_000_000,
2438 input_tokens: 32_000,
2439 output_tokens: 16_000,
2440 })
2441 );
2442 });
2443
2444 thread
2445 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2446 .unwrap();
2447 cx.run_until_parked();
2448 thread.read_with(cx, |thread, _| {
2449 assert_eq!(thread.to_markdown(), "");
2450 assert_eq!(thread.latest_token_usage(), None);
2451 });
2452
2453 // Ensure we can still send a new message after truncation.
2454 thread
2455 .update(cx, |thread, cx| {
2456 thread.send(UserMessageId::new(), ["Hi"], cx)
2457 })
2458 .unwrap();
2459 thread.update(cx, |thread, _cx| {
2460 assert_eq!(
2461 thread.to_markdown(),
2462 indoc! {"
2463 ## User
2464
2465 Hi
2466 "}
2467 );
2468 });
2469 cx.run_until_parked();
2470 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2471 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2472 language_model::TokenUsage {
2473 input_tokens: 40_000,
2474 output_tokens: 20_000,
2475 cache_creation_input_tokens: 0,
2476 cache_read_input_tokens: 0,
2477 },
2478 ));
2479 cx.run_until_parked();
2480 thread.read_with(cx, |thread, _| {
2481 assert_eq!(
2482 thread.to_markdown(),
2483 indoc! {"
2484 ## User
2485
2486 Hi
2487
2488 ## Assistant
2489
2490 Ahoy!
2491 "}
2492 );
2493
2494 assert_eq!(
2495 thread.latest_token_usage(),
2496 Some(acp_thread::TokenUsage {
2497 used_tokens: 40_000 + 20_000,
2498 max_tokens: 1_000_000,
2499 input_tokens: 40_000,
2500 output_tokens: 20_000,
2501 })
2502 );
2503 });
2504}
2505
2506#[gpui::test]
2507async fn test_truncate_second_message(cx: &mut TestAppContext) {
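    // Tests that truncating at the second message restores the thread and token usage to the
    // state after the first exchange.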
2508 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2509 let fake_model = model.as_fake();
2510
2511 thread
2512 .update(cx, |thread, cx| {
2513 thread.send(UserMessageId::new(), ["Message 1"], cx)
2514 })
2515 .unwrap();
2516 cx.run_until_parked();
2517 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2518 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2519 language_model::TokenUsage {
2520 input_tokens: 32_000,
2521 output_tokens: 16_000,
2522 cache_creation_input_tokens: 0,
2523 cache_read_input_tokens: 0,
2524 },
2525 ));
2526 fake_model.end_last_completion_stream();
2527 cx.run_until_parked();
2528
2529 let assert_first_message_state = |cx: &mut TestAppContext| {
2530 thread.clone().read_with(cx, |thread, _| {
2531 assert_eq!(
2532 thread.to_markdown(),
2533 indoc! {"
2534 ## User
2535
2536 Message 1
2537
2538 ## Assistant
2539
2540 Message 1 response
2541 "}
2542 );
2543
2544 assert_eq!(
2545 thread.latest_token_usage(),
2546 Some(acp_thread::TokenUsage {
2547 used_tokens: 32_000 + 16_000,
2548 max_tokens: 1_000_000,
2549 input_tokens: 32_000,
2550 output_tokens: 16_000,
2551 })
2552 );
2553 });
2554 };
2555
2556 assert_first_message_state(cx);
2557
2558 let second_message_id = UserMessageId::new();
2559 thread
2560 .update(cx, |thread, cx| {
2561 thread.send(second_message_id.clone(), ["Message 2"], cx)
2562 })
2563 .unwrap();
2564 cx.run_until_parked();
2565
2566 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2567 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2568 language_model::TokenUsage {
2569 input_tokens: 40_000,
2570 output_tokens: 20_000,
2571 cache_creation_input_tokens: 0,
2572 cache_read_input_tokens: 0,
2573 },
2574 ));
2575 fake_model.end_last_completion_stream();
2576 cx.run_until_parked();
2577
2578 thread.read_with(cx, |thread, _| {
2579 assert_eq!(
2580 thread.to_markdown(),
2581 indoc! {"
2582 ## User
2583
2584 Message 1
2585
2586 ## Assistant
2587
2588 Message 1 response
2589
2590 ## User
2591
2592 Message 2
2593
2594 ## Assistant
2595
2596 Message 2 response
2597 "}
2598 );
2599
2600 assert_eq!(
2601 thread.latest_token_usage(),
2602 Some(acp_thread::TokenUsage {
2603 used_tokens: 40_000 + 20_000,
2604 max_tokens: 1_000_000,
2605 input_tokens: 40_000,
2606 output_tokens: 20_000,
2607 })
2608 );
2609 });
2610
2611 thread
2612 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2613 .unwrap();
2614 cx.run_until_parked();
2615
2616 assert_first_message_state(cx);
2617}
2618
2619#[gpui::test]
2620async fn test_title_generation(cx: &mut TestAppContext) {
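    // Tests that the summarization model generates a title after the first exchange and is not
    // invoked again for later messages.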
2621 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2622 let fake_model = model.as_fake();
2623
2624 let summary_model = Arc::new(FakeLanguageModel::default());
2625 thread.update(cx, |thread, cx| {
2626 thread.set_summarization_model(Some(summary_model.clone()), cx)
2627 });
2628
2629 let send = thread
2630 .update(cx, |thread, cx| {
2631 thread.send(UserMessageId::new(), ["Hello"], cx)
2632 })
2633 .unwrap();
2634 cx.run_until_parked();
2635
2636 fake_model.send_last_completion_stream_text_chunk("Hey!");
2637 fake_model.end_last_completion_stream();
2638 cx.run_until_parked();
2639 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2640
    // Simulate the summary model streaming a title (this also verifies it was invoked).
2642 summary_model.send_last_completion_stream_text_chunk("Hello ");
2643 summary_model.send_last_completion_stream_text_chunk("world\nG");
2644 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2645 summary_model.end_last_completion_stream();
2646 send.collect::<Vec<_>>().await;
2647 cx.run_until_parked();
2648 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2649
2650 // Send another message, ensuring no title is generated this time.
2651 let send = thread
2652 .update(cx, |thread, cx| {
2653 thread.send(UserMessageId::new(), ["Hello again"], cx)
2654 })
2655 .unwrap();
2656 cx.run_until_parked();
2657 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2658 fake_model.end_last_completion_stream();
2659 cx.run_until_parked();
2660 assert_eq!(summary_model.pending_completions(), Vec::new());
2661 send.collect::<Vec<_>>().await;
2662 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2663}
2664
2665#[gpui::test]
2666async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
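    // Tests that a tool call still awaiting permission is omitted from the next completion
    // request, while a completed tool call and its result are included.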
2667 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2668 let fake_model = model.as_fake();
2669
2670 let _events = thread
2671 .update(cx, |thread, cx| {
2672 thread.add_tool(ToolRequiringPermission);
2673 thread.add_tool(EchoTool);
2674 thread.send(UserMessageId::new(), ["Hey!"], cx)
2675 })
2676 .unwrap();
2677 cx.run_until_parked();
2678
2679 let permission_tool_use = LanguageModelToolUse {
2680 id: "tool_id_1".into(),
2681 name: ToolRequiringPermission::name().into(),
2682 raw_input: "{}".into(),
2683 input: json!({}),
2684 is_input_complete: true,
2685 thought_signature: None,
2686 };
2687 let echo_tool_use = LanguageModelToolUse {
2688 id: "tool_id_2".into(),
2689 name: EchoTool::name().into(),
2690 raw_input: json!({"text": "test"}).to_string(),
2691 input: json!({"text": "test"}),
2692 is_input_complete: true,
2693 thought_signature: None,
2694 };
2695 fake_model.send_last_completion_stream_text_chunk("Hi!");
2696 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2697 permission_tool_use,
2698 ));
2699 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2700 echo_tool_use.clone(),
2701 ));
2702 fake_model.end_last_completion_stream();
2703 cx.run_until_parked();
2704
2705 // Ensure pending tools are skipped when building a request.
2706 let request = thread
2707 .read_with(cx, |thread, cx| {
2708 thread.build_completion_request(CompletionIntent::EditFile, cx)
2709 })
2710 .unwrap();
2711 assert_eq!(
2712 request.messages[1..],
2713 vec![
2714 LanguageModelRequestMessage {
2715 role: Role::User,
2716 content: vec!["Hey!".into()],
2717 cache: true,
2718 reasoning_details: None,
2719 },
2720 LanguageModelRequestMessage {
2721 role: Role::Assistant,
2722 content: vec![
2723 MessageContent::Text("Hi!".into()),
2724 MessageContent::ToolUse(echo_tool_use.clone())
2725 ],
2726 cache: false,
2727 reasoning_details: None,
2728 },
2729 LanguageModelRequestMessage {
2730 role: Role::User,
2731 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2732 tool_use_id: echo_tool_use.id.clone(),
2733 tool_name: echo_tool_use.name,
2734 is_error: false,
2735 content: "test".into(),
2736 output: Some("test".into())
2737 })],
2738 cache: false,
2739 reasoning_details: None,
2740 },
2741 ],
2742 );
2743}
2744
2745#[gpui::test]
2746async fn test_agent_connection(cx: &mut TestAppContext) {
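    // End-to-end test of NativeAgentConnection: creating a thread, listing and selecting models,
    // prompting, cancelling, and tearing down the session.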
2747 cx.update(settings::init);
2748 let templates = Templates::new();
2749
2750 // Initialize language model system with test provider
2751 cx.update(|cx| {
2752 gpui_tokio::init(cx);
2753
2754 let http_client = FakeHttpClient::with_404_response();
2755 let clock = Arc::new(clock::FakeSystemClock::new());
2756 let client = Client::new(clock, http_client, cx);
2757 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2758 language_model::init(client.clone(), cx);
2759 language_models::init(user_store, client.clone(), cx);
2760 LanguageModelRegistry::test(cx);
2761 });
2762 cx.executor().forbid_parking();
2763
2764 // Create a project for new_thread
2765 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2766 fake_fs.insert_tree(path!("/test"), json!({})).await;
2767 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2768 let cwd = Path::new("/test");
2769 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2770
2771 // Create agent and connection
2772 let agent = NativeAgent::new(
2773 project.clone(),
2774 thread_store,
2775 templates.clone(),
2776 None,
2777 fake_fs.clone(),
2778 &mut cx.to_async(),
2779 )
2780 .await
2781 .unwrap();
2782 let connection = NativeAgentConnection(agent.clone());
2783
2784 // Create a thread using new_thread
2785 let connection_rc = Rc::new(connection.clone());
2786 let acp_thread = cx
2787 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2788 .await
2789 .expect("new_thread should succeed");
2790
2791 // Get the session_id from the AcpThread
2792 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2793
2794 // Test model_selector returns Some
2795 let selector_opt = connection.model_selector(&session_id);
2796 assert!(
2797 selector_opt.is_some(),
2798 "agent should always support ModelSelector"
2799 );
2800 let selector = selector_opt.unwrap();
2801
2802 // Test list_models
2803 let listed_models = cx
2804 .update(|cx| selector.list_models(cx))
2805 .await
2806 .expect("list_models should succeed");
2807 let AgentModelList::Grouped(listed_models) = listed_models else {
2808 panic!("Unexpected model list type");
2809 };
2810 assert!(!listed_models.is_empty(), "should have at least one model");
2811 assert_eq!(
2812 listed_models[&AgentModelGroupName("Fake".into())][0]
2813 .id
2814 .0
2815 .as_ref(),
2816 "fake/fake"
2817 );
2818
2819 // Test selected_model returns the default
2820 let model = cx
2821 .update(|cx| selector.selected_model(cx))
2822 .await
2823 .expect("selected_model should succeed");
2824 let model = cx
2825 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2826 .unwrap();
2827 let model = model.as_fake();
2828 assert_eq!(model.id().0, "fake", "should return default model");
2829
2830 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2831 cx.run_until_parked();
2832 model.send_last_completion_stream_text_chunk("def");
2833 cx.run_until_parked();
2834 acp_thread.read_with(cx, |thread, cx| {
2835 assert_eq!(
2836 thread.to_markdown(cx),
2837 indoc! {"
2838 ## User
2839
2840 abc
2841
2842 ## Assistant
2843
2844 def
2845
2846 "}
2847 )
2848 });
2849
2850 // Test cancel
2851 cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should resolve cleanly after cancellation");
2853
2854 // Ensure that dropping the ACP thread causes the native thread to be
2855 // dropped as well.
2856 cx.update(|_| drop(acp_thread));
2857 let result = cx
2858 .update(|cx| {
2859 connection.prompt(
2860 Some(acp_thread::UserMessageId::new()),
2861 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2862 cx,
2863 )
2864 })
2865 .await;
2866 assert_eq!(
2867 result.as_ref().unwrap_err().to_string(),
2868 "Session not found",
2869 "unexpected result: {:?}",
2870 result
2871 );
2872}
2873
2874#[gpui::test]
2875async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
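    // Tests the sequence of tool call updates emitted as a tool's input streams in, the tool
    // runs, and it completes.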
2876 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2877 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2878 let fake_model = model.as_fake();
2879
2880 let mut events = thread
2881 .update(cx, |thread, cx| {
2882 thread.send(UserMessageId::new(), ["Think"], cx)
2883 })
2884 .unwrap();
2885 cx.run_until_parked();
2886
2887 // Simulate streaming partial input.
2888 let input = json!({});
2889 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2890 LanguageModelToolUse {
2891 id: "1".into(),
2892 name: ThinkingTool::name().into(),
2893 raw_input: input.to_string(),
2894 input,
2895 is_input_complete: false,
2896 thought_signature: None,
2897 },
2898 ));
2899
2900 // Input streaming completed
2901 let input = json!({ "content": "Thinking hard!" });
2902 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2903 LanguageModelToolUse {
2904 id: "1".into(),
2905 name: "thinking".into(),
2906 raw_input: input.to_string(),
2907 input,
2908 is_input_complete: true,
2909 thought_signature: None,
2910 },
2911 ));
2912 fake_model.end_last_completion_stream();
2913 cx.run_until_parked();
2914
2915 let tool_call = expect_tool_call(&mut events).await;
2916 assert_eq!(
2917 tool_call,
2918 acp::ToolCall::new("1", "Thinking")
2919 .kind(acp::ToolKind::Think)
2920 .raw_input(json!({}))
2921 .meta(acp::Meta::from_iter([(
2922 "tool_name".into(),
2923 "thinking".into()
2924 )]))
2925 );
2926 let update = expect_tool_call_update_fields(&mut events).await;
2927 assert_eq!(
2928 update,
2929 acp::ToolCallUpdate::new(
2930 "1",
2931 acp::ToolCallUpdateFields::new()
2932 .title("Thinking")
2933 .kind(acp::ToolKind::Think)
2934 .raw_input(json!({ "content": "Thinking hard!"}))
2935 )
2936 );
2937 let update = expect_tool_call_update_fields(&mut events).await;
2938 assert_eq!(
2939 update,
2940 acp::ToolCallUpdate::new(
2941 "1",
2942 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2943 )
2944 );
2945 let update = expect_tool_call_update_fields(&mut events).await;
2946 assert_eq!(
2947 update,
2948 acp::ToolCallUpdate::new(
2949 "1",
2950 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2951 )
2952 );
2953 let update = expect_tool_call_update_fields(&mut events).await;
2954 assert_eq!(
2955 update,
2956 acp::ToolCallUpdate::new(
2957 "1",
2958 acp::ToolCallUpdateFields::new()
2959 .status(acp::ToolCallStatus::Completed)
2960 .raw_output("Finished thinking.")
2961 )
2962 );
2963}
2964
2965#[gpui::test]
2966async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
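    // Tests that a completion that succeeds on the first attempt emits no Retry events.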
2967 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2968 let fake_model = model.as_fake();
2969
2970 let mut events = thread
2971 .update(cx, |thread, cx| {
2972 thread.send(UserMessageId::new(), ["Hello!"], cx)
2973 })
2974 .unwrap();
2975 cx.run_until_parked();
2976
2977 fake_model.send_last_completion_stream_text_chunk("Hey!");
2978 fake_model.end_last_completion_stream();
2979
2980 let mut retry_events = Vec::new();
2981 while let Some(Ok(event)) = events.next().await {
2982 match event {
2983 ThreadEvent::Retry(retry_status) => {
2984 retry_events.push(retry_status);
2985 }
2986 ThreadEvent::Stop(..) => break,
2987 _ => {}
2988 }
2989 }
2990
2991 assert_eq!(retry_events.len(), 0);
2992 thread.read_with(cx, |thread, _cx| {
2993 assert_eq!(
2994 thread.to_markdown(),
2995 indoc! {"
2996 ## User
2997
2998 Hello!
2999
3000 ## Assistant
3001
3002 Hey!
3003 "}
3004 )
3005 });
3006}
3007
3008#[gpui::test]
3009async fn test_send_retry_on_error(cx: &mut TestAppContext) {
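    // Tests that a retryable provider error triggers a retry and that the resumed completion is
    // appended to the thread.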
3010 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3011 let fake_model = model.as_fake();
3012
3013 let mut events = thread
3014 .update(cx, |thread, cx| {
3015 thread.send(UserMessageId::new(), ["Hello!"], cx)
3016 })
3017 .unwrap();
3018 cx.run_until_parked();
3019
3020 fake_model.send_last_completion_stream_text_chunk("Hey,");
3021 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3022 provider: LanguageModelProviderName::new("Anthropic"),
3023 retry_after: Some(Duration::from_secs(3)),
3024 });
3025 fake_model.end_last_completion_stream();
3026
3027 cx.executor().advance_clock(Duration::from_secs(3));
3028 cx.run_until_parked();
3029
3030 fake_model.send_last_completion_stream_text_chunk("there!");
3031 fake_model.end_last_completion_stream();
3032 cx.run_until_parked();
3033
3034 let mut retry_events = Vec::new();
3035 while let Some(Ok(event)) = events.next().await {
3036 match event {
3037 ThreadEvent::Retry(retry_status) => {
3038 retry_events.push(retry_status);
3039 }
3040 ThreadEvent::Stop(..) => break,
3041 _ => {}
3042 }
3043 }
3044
3045 assert_eq!(retry_events.len(), 1);
3046 assert!(matches!(
3047 retry_events[0],
3048 acp_thread::RetryStatus { attempt: 1, .. }
3049 ));
3050 thread.read_with(cx, |thread, _cx| {
3051 assert_eq!(
3052 thread.to_markdown(),
3053 indoc! {"
3054 ## User
3055
3056 Hello!
3057
3058 ## Assistant
3059
3060 Hey,
3061
3062 [resume]
3063
3064 ## Assistant
3065
3066 there!
3067 "}
3068 )
3069 });
3070}
3071
3072#[gpui::test]
3073async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
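    // Tests that a tool call streamed before a retryable error is still resolved, so the retried
    // request includes its result.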
3074 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3075 let fake_model = model.as_fake();
3076
3077 let events = thread
3078 .update(cx, |thread, cx| {
3079 thread.add_tool(EchoTool);
3080 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3081 })
3082 .unwrap();
3083 cx.run_until_parked();
3084
3085 let tool_use_1 = LanguageModelToolUse {
3086 id: "tool_1".into(),
3087 name: EchoTool::name().into(),
3088 raw_input: json!({"text": "test"}).to_string(),
3089 input: json!({"text": "test"}),
3090 is_input_complete: true,
3091 thought_signature: None,
3092 };
3093 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3094 tool_use_1.clone(),
3095 ));
3096 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3097 provider: LanguageModelProviderName::new("Anthropic"),
3098 retry_after: Some(Duration::from_secs(3)),
3099 });
3100 fake_model.end_last_completion_stream();
3101
3102 cx.executor().advance_clock(Duration::from_secs(3));
3103 let completion = fake_model.pending_completions().pop().unwrap();
3104 assert_eq!(
3105 completion.messages[1..],
3106 vec![
3107 LanguageModelRequestMessage {
3108 role: Role::User,
3109 content: vec!["Call the echo tool!".into()],
3110 cache: false,
3111 reasoning_details: None,
3112 },
3113 LanguageModelRequestMessage {
3114 role: Role::Assistant,
3115 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3116 cache: false,
3117 reasoning_details: None,
3118 },
3119 LanguageModelRequestMessage {
3120 role: Role::User,
3121 content: vec![language_model::MessageContent::ToolResult(
3122 LanguageModelToolResult {
3123 tool_use_id: tool_use_1.id.clone(),
3124 tool_name: tool_use_1.name.clone(),
3125 is_error: false,
3126 content: "test".into(),
3127 output: Some("test".into())
3128 }
3129 )],
3130 cache: true,
3131 reasoning_details: None,
3132 },
3133 ]
3134 );
3135
3136 fake_model.send_last_completion_stream_text_chunk("Done");
3137 fake_model.end_last_completion_stream();
3138 cx.run_until_parked();
3139 events.collect::<Vec<_>>().await;
3140 thread.read_with(cx, |thread, _cx| {
3141 assert_eq!(
3142 thread.last_message(),
3143 Some(Message::Agent(AgentMessage {
3144 content: vec![AgentMessageContent::Text("Done".into())],
3145 tool_results: IndexMap::default(),
3146 reasoning_details: None,
3147 }))
3148 );
3149 })
3150}
3151
3152#[gpui::test]
3153async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
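    // Tests that retrying stops after MAX_RETRY_ATTEMPTS and the final error is surfaced on the
    // event stream.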
3154 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3155 let fake_model = model.as_fake();
3156
3157 let mut events = thread
3158 .update(cx, |thread, cx| {
3159 thread.send(UserMessageId::new(), ["Hello!"], cx)
3160 })
3161 .unwrap();
3162 cx.run_until_parked();
3163
3164 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3165 fake_model.send_last_completion_stream_error(
3166 LanguageModelCompletionError::ServerOverloaded {
3167 provider: LanguageModelProviderName::new("Anthropic"),
3168 retry_after: Some(Duration::from_secs(3)),
3169 },
3170 );
3171 fake_model.end_last_completion_stream();
3172 cx.executor().advance_clock(Duration::from_secs(3));
3173 cx.run_until_parked();
3174 }
3175
3176 let mut errors = Vec::new();
3177 let mut retry_events = Vec::new();
3178 while let Some(event) = events.next().await {
3179 match event {
3180 Ok(ThreadEvent::Retry(retry_status)) => {
3181 retry_events.push(retry_status);
3182 }
3183 Ok(ThreadEvent::Stop(..)) => break,
3184 Err(error) => errors.push(error),
3185 _ => {}
3186 }
3187 }
3188
3189 assert_eq!(
3190 retry_events.len(),
3191 crate::thread::MAX_RETRY_ATTEMPTS as usize
3192 );
3193 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3194 assert_eq!(retry_events[i].attempt, i + 1);
3195 }
3196 assert_eq!(errors.len(), 1);
3197 let error = errors[0]
3198 .downcast_ref::<LanguageModelCompletionError>()
3199 .unwrap();
3200 assert!(matches!(
3201 error,
3202 LanguageModelCompletionError::ServerOverloaded { .. }
3203 ));
3204}
3205
/// Extracts the stop events from a list of thread events so tests can assert on them.
3207fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3208 result_events
3209 .into_iter()
3210 .filter_map(|event| match event.unwrap() {
3211 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3212 _ => None,
3213 })
3214 .collect()
3215}
3216
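/// Test fixture returned by `setup`, bundling the model, thread, and supporting entities.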
3217struct ThreadTest {
3218 model: Arc<dyn LanguageModel>,
3219 thread: Entity<Thread>,
3220 project_context: Entity<ProjectContext>,
3221 context_server_store: Entity<ContextServerStore>,
3222 fs: Arc<FakeFs>,
3223}
3224
3225enum TestModel {
3226 Sonnet4,
3227 Fake,
3228}
3229
3230impl TestModel {
3231 fn id(&self) -> LanguageModelId {
3232 match self {
3233 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3234 TestModel::Fake => unreachable!(),
3235 }
3236 }
3237}
3238
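/// Builds a `ThreadTest`: a fake filesystem with a test profile enabling every test tool, a
/// test project, and a `Thread` backed by the requested model.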
3239async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3240 cx.executor().allow_parking();
3241
3242 let fs = FakeFs::new(cx.background_executor.clone());
3243 fs.create_dir(paths::settings_file().parent().unwrap())
3244 .await
3245 .unwrap();
3246 fs.insert_file(
3247 paths::settings_file(),
3248 json!({
3249 "agent": {
3250 "default_profile": "test-profile",
3251 "profiles": {
3252 "test-profile": {
3253 "name": "Test Profile",
3254 "tools": {
3255 EchoTool::name(): true,
3256 DelayTool::name(): true,
3257 WordListTool::name(): true,
3258 ToolRequiringPermission::name(): true,
3259 InfiniteTool::name(): true,
3260 CancellationAwareTool::name(): true,
3261 ThinkingTool::name(): true,
3262 "terminal": true,
3263 }
3264 }
3265 }
3266 }
3267 })
3268 .to_string()
3269 .into_bytes(),
3270 )
3271 .await;
3272
3273 cx.update(|cx| {
3274 settings::init(cx);
3275
3276 match model {
3277 TestModel::Fake => {}
3278 TestModel::Sonnet4 => {
3279 gpui_tokio::init(cx);
3280 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3281 cx.set_http_client(Arc::new(http_client));
3282 let client = Client::production(cx);
3283 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3284 language_model::init(client.clone(), cx);
3285 language_models::init(user_store, client.clone(), cx);
3286 }
3287 };
3288
3289 watch_settings(fs.clone(), cx);
3290 });
3291
3292 let templates = Templates::new();
3293
3294 fs.insert_tree(path!("/test"), json!({})).await;
3295 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3296
3297 let model = cx
3298 .update(|cx| {
3299 if let TestModel::Fake = model {
3300 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3301 } else {
3302 let model_id = model.id();
3303 let models = LanguageModelRegistry::read_global(cx);
3304 let model = models
3305 .available_models(cx)
3306 .find(|model| model.id() == model_id)
3307 .unwrap();
3308
3309 let provider = models.provider(&model.provider_id()).unwrap();
3310 let authenticated = provider.authenticate(cx);
3311
3312 cx.spawn(async move |_cx| {
3313 authenticated.await.unwrap();
3314 model
3315 })
3316 }
3317 })
3318 .await;
3319
3320 let project_context = cx.new(|_cx| ProjectContext::default());
3321 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3322 let context_server_registry =
3323 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3324 let thread = cx.new(|cx| {
3325 Thread::new(
3326 project,
3327 project_context.clone(),
3328 context_server_registry,
3329 templates,
3330 Some(model.clone()),
3331 cx,
3332 )
3333 });
3334 ThreadTest {
3335 model,
3336 thread,
3337 project_context,
3338 context_server_store,
3339 fs,
3340 }
3341}
3342
3343#[cfg(test)]
3344#[ctor::ctor]
3345fn init_logger() {
3346 if std::env::var("RUST_LOG").is_ok() {
3347 env_logger::init();
3348 }
3349}
3350
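/// Keeps the global `SettingsStore` in sync with changes to the settings file on the given
/// filesystem.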
3351fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3352 let fs = fs.clone();
3353 cx.spawn({
3354 async move |cx| {
3355 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3356 cx.background_executor(),
3357 fs,
3358 paths::settings_file().clone(),
3359 );
3360 let _watcher_task = watcher_task;
3361
3362 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3363 cx.update(|cx| {
3364 SettingsStore::update_global(cx, |settings, cx| {
3365 settings.set_user_settings(&new_settings_content, cx)
3366 })
3367 })
3368 .ok();
3369 }
3370 }
3371 })
3372 .detach();
3373}
3374
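/// Returns the names of the tools offered to the model in a completion request.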
3375fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3376 completion
3377 .tools
3378 .iter()
3379 .map(|tool| tool.name.clone())
3380 .collect()
3381}
3382
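/// Registers a fake stdio context server exposing the given tools and returns a channel of
/// incoming tool calls paired with their response senders.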
3383fn setup_context_server(
3384 name: &'static str,
3385 tools: Vec<context_server::types::Tool>,
3386 context_server_store: &Entity<ContextServerStore>,
3387 cx: &mut TestAppContext,
3388) -> mpsc::UnboundedReceiver<(
3389 context_server::types::CallToolParams,
3390 oneshot::Sender<context_server::types::CallToolResponse>,
3391)> {
3392 cx.update(|cx| {
3393 let mut settings = ProjectSettings::get_global(cx).clone();
3394 settings.context_servers.insert(
3395 name.into(),
3396 project::project_settings::ContextServerSettings::Stdio {
3397 enabled: true,
3398 remote: false,
3399 command: ContextServerCommand {
3400 path: "somebinary".into(),
3401 args: Vec::new(),
3402 env: None,
3403 timeout: None,
3404 },
3405 },
3406 );
3407 ProjectSettings::override_global(settings, cx);
3408 });
3409
3410 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3411 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3412 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3413 context_server::types::InitializeResponse {
3414 protocol_version: context_server::types::ProtocolVersion(
3415 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3416 ),
3417 server_info: context_server::types::Implementation {
3418 name: name.into(),
3419 version: "1.0.0".to_string(),
3420 },
3421 capabilities: context_server::types::ServerCapabilities {
3422 tools: Some(context_server::types::ToolsCapabilities {
3423 list_changed: Some(true),
3424 }),
3425 ..Default::default()
3426 },
3427 meta: None,
3428 }
3429 })
3430 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3431 let tools = tools.clone();
3432 async move {
3433 context_server::types::ListToolsResponse {
3434 tools,
3435 next_cursor: None,
3436 meta: None,
3437 }
3438 }
3439 })
3440 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3441 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3442 async move {
3443 let (response_tx, response_rx) = oneshot::channel();
3444 mcp_tool_calls_tx
3445 .unbounded_send((params, response_tx))
3446 .unwrap();
3447 response_rx.await.unwrap()
3448 }
3449 });
3450 context_server_store.update(cx, |store, cx| {
3451 store.start_server(
3452 Arc::new(ContextServer::new(
3453 ContextServerId(name.into()),
3454 Arc::new(fake_transport),
3455 )),
3456 cx,
3457 );
3458 });
3459 cx.run_until_parked();
3460 mcp_tool_calls_rx
3461}
3462
3463#[gpui::test]
3464async fn test_tokens_before_message(cx: &mut TestAppContext) {
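    // Tests that tokens_before_message reports the input token count of the request that
    // preceded each message.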
3465 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3466 let fake_model = model.as_fake();
3467
3468 // First message
3469 let message_1_id = UserMessageId::new();
3470 thread
3471 .update(cx, |thread, cx| {
3472 thread.send(message_1_id.clone(), ["First message"], cx)
3473 })
3474 .unwrap();
3475 cx.run_until_parked();
3476
    // Before any response, tokens_before_message should return None for the first message
3478 thread.read_with(cx, |thread, _| {
3479 assert_eq!(
3480 thread.tokens_before_message(&message_1_id),
3481 None,
3482 "First message should have no tokens before it"
3483 );
3484 });
3485
3486 // Complete first message with usage
3487 fake_model.send_last_completion_stream_text_chunk("Response 1");
3488 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3489 language_model::TokenUsage {
3490 input_tokens: 100,
3491 output_tokens: 50,
3492 cache_creation_input_tokens: 0,
3493 cache_read_input_tokens: 0,
3494 },
3495 ));
3496 fake_model.end_last_completion_stream();
3497 cx.run_until_parked();
3498
3499 // First message still has no tokens before it
3500 thread.read_with(cx, |thread, _| {
3501 assert_eq!(
3502 thread.tokens_before_message(&message_1_id),
3503 None,
3504 "First message should still have no tokens before it after response"
3505 );
3506 });
3507
3508 // Second message
3509 let message_2_id = UserMessageId::new();
3510 thread
3511 .update(cx, |thread, cx| {
3512 thread.send(message_2_id.clone(), ["Second message"], cx)
3513 })
3514 .unwrap();
3515 cx.run_until_parked();
3516
3517 // Second message should have first message's input tokens before it
3518 thread.read_with(cx, |thread, _| {
3519 assert_eq!(
3520 thread.tokens_before_message(&message_2_id),
3521 Some(100),
3522 "Second message should have 100 tokens before it (from first request)"
3523 );
3524 });
3525
3526 // Complete second message
3527 fake_model.send_last_completion_stream_text_chunk("Response 2");
3528 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3529 language_model::TokenUsage {
3530 input_tokens: 250, // Total for this request (includes previous context)
3531 output_tokens: 75,
3532 cache_creation_input_tokens: 0,
3533 cache_read_input_tokens: 0,
3534 },
3535 ));
3536 fake_model.end_last_completion_stream();
3537 cx.run_until_parked();
3538
3539 // Third message
3540 let message_3_id = UserMessageId::new();
3541 thread
3542 .update(cx, |thread, cx| {
3543 thread.send(message_3_id.clone(), ["Third message"], cx)
3544 })
3545 .unwrap();
3546 cx.run_until_parked();
3547
3548 // Third message should have second message's input tokens (250) before it
3549 thread.read_with(cx, |thread, _| {
3550 assert_eq!(
3551 thread.tokens_before_message(&message_3_id),
3552 Some(250),
3553 "Third message should have 250 tokens before it (from second request)"
3554 );
3555 // Second message should still have 100
3556 assert_eq!(
3557 thread.tokens_before_message(&message_2_id),
3558 Some(100),
3559 "Second message should still have 100 tokens before it"
3560 );
3561 // First message still has none
3562 assert_eq!(
3563 thread.tokens_before_message(&message_1_id),
3564 None,
3565 "First message should still have no tokens before it"
3566 );
3567 });
3568}
3569
3570#[gpui::test]
3571async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
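    // Tests that token accounting is discarded for messages removed by truncation.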
3572 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3573 let fake_model = model.as_fake();
3574
3575 // Set up three messages with responses
3576 let message_1_id = UserMessageId::new();
3577 thread
3578 .update(cx, |thread, cx| {
3579 thread.send(message_1_id.clone(), ["Message 1"], cx)
3580 })
3581 .unwrap();
3582 cx.run_until_parked();
3583 fake_model.send_last_completion_stream_text_chunk("Response 1");
3584 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3585 language_model::TokenUsage {
3586 input_tokens: 100,
3587 output_tokens: 50,
3588 cache_creation_input_tokens: 0,
3589 cache_read_input_tokens: 0,
3590 },
3591 ));
3592 fake_model.end_last_completion_stream();
3593 cx.run_until_parked();
3594
3595 let message_2_id = UserMessageId::new();
3596 thread
3597 .update(cx, |thread, cx| {
3598 thread.send(message_2_id.clone(), ["Message 2"], cx)
3599 })
3600 .unwrap();
3601 cx.run_until_parked();
3602 fake_model.send_last_completion_stream_text_chunk("Response 2");
3603 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3604 language_model::TokenUsage {
3605 input_tokens: 250,
3606 output_tokens: 75,
3607 cache_creation_input_tokens: 0,
3608 cache_read_input_tokens: 0,
3609 },
3610 ));
3611 fake_model.end_last_completion_stream();
3612 cx.run_until_parked();
3613
3614 // Verify initial state
3615 thread.read_with(cx, |thread, _| {
3616 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3617 });
3618
3619 // Truncate at message 2 (removes message 2 and everything after)
3620 thread
3621 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3622 .unwrap();
3623 cx.run_until_parked();
3624
3625 // After truncation, message_2_id no longer exists, so lookup should return None
3626 thread.read_with(cx, |thread, _| {
3627 assert_eq!(
3628 thread.tokens_before_message(&message_2_id),
3629 None,
3630 "After truncation, message 2 no longer exists"
3631 );
3632 // Message 1 still exists but has no tokens before it
3633 assert_eq!(
3634 thread.tokens_before_message(&message_1_id),
3635 None,
3636 "First message still has no tokens before it"
3637 );
3638 });
3639}
3640
3641#[gpui::test]
3642async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
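    // Exercises terminal tool permission rules: deny patterns block commands, allow patterns
    // skip confirmation, and always_allow_tool_actions overrides confirm/deny defaults.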
3643 init_test(cx);
3644
3645 let fs = FakeFs::new(cx.executor());
3646 fs.insert_tree("/root", json!({})).await;
3647 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3648
3649 // Test 1: Deny rule blocks command
3650 {
3651 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3652 let environment = Rc::new(FakeThreadEnvironment {
3653 handle: handle.clone(),
3654 });
3655
3656 cx.update(|cx| {
3657 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3658 settings.tool_permissions.tools.insert(
3659 "terminal".into(),
3660 agent_settings::ToolRules {
3661 default_mode: settings::ToolPermissionMode::Confirm,
3662 always_allow: vec![],
3663 always_deny: vec![
3664 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3665 ],
3666 always_confirm: vec![],
3667 invalid_patterns: vec![],
3668 },
3669 );
3670 agent_settings::AgentSettings::override_global(settings, cx);
3671 });
3672
3673 #[allow(clippy::arc_with_non_send_sync)]
3674 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3675 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3676
3677 let task = cx.update(|cx| {
3678 tool.run(
3679 crate::TerminalToolInput {
3680 command: "rm -rf /".to_string(),
3681 cd: ".".to_string(),
3682 timeout_ms: None,
3683 },
3684 event_stream,
3685 cx,
3686 )
3687 });
3688
3689 let result = task.await;
3690 assert!(
3691 result.is_err(),
3692 "expected command to be blocked by deny rule"
3693 );
3694 assert!(
3695 result.unwrap_err().to_string().contains("blocked"),
3696 "error should mention the command was blocked"
3697 );
3698 }
3699
3700 // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3701 {
3702 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3703 let environment = Rc::new(FakeThreadEnvironment {
3704 handle: handle.clone(),
3705 });
3706
3707 cx.update(|cx| {
3708 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3709 settings.always_allow_tool_actions = false;
3710 settings.tool_permissions.tools.insert(
3711 "terminal".into(),
3712 agent_settings::ToolRules {
3713 default_mode: settings::ToolPermissionMode::Deny,
3714 always_allow: vec![
3715 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3716 ],
3717 always_deny: vec![],
3718 always_confirm: vec![],
3719 invalid_patterns: vec![],
3720 },
3721 );
3722 agent_settings::AgentSettings::override_global(settings, cx);
3723 });
3724
3725 #[allow(clippy::arc_with_non_send_sync)]
3726 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3727 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3728
3729 let task = cx.update(|cx| {
3730 tool.run(
3731 crate::TerminalToolInput {
3732 command: "echo hello".to_string(),
3733 cd: ".".to_string(),
3734 timeout_ms: None,
3735 },
3736 event_stream,
3737 cx,
3738 )
3739 });
3740
3741 let update = rx.expect_update_fields().await;
3742 assert!(
3743 update.content.iter().any(|blocks| {
3744 blocks
3745 .iter()
3746 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3747 }),
3748 "expected terminal content (allow rule should skip confirmation and override default deny)"
3749 );
3750
3751 let result = task.await;
3752 assert!(
3753 result.is_ok(),
3754 "expected command to succeed without confirmation"
3755 );
3756 }
3757
3758 // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
3759 {
3760 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3761 let environment = Rc::new(FakeThreadEnvironment {
3762 handle: handle.clone(),
3763 });
3764
3765 cx.update(|cx| {
3766 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3767 settings.always_allow_tool_actions = true;
3768 settings.tool_permissions.tools.insert(
3769 "terminal".into(),
3770 agent_settings::ToolRules {
3771 default_mode: settings::ToolPermissionMode::Allow,
3772 always_allow: vec![],
3773 always_deny: vec![],
3774 always_confirm: vec![
3775 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3776 ],
3777 invalid_patterns: vec![],
3778 },
3779 );
3780 agent_settings::AgentSettings::override_global(settings, cx);
3781 });
3782
3783 #[allow(clippy::arc_with_non_send_sync)]
3784 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3785 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3786
3787 let task = cx.update(|cx| {
3788 tool.run(
3789 crate::TerminalToolInput {
3790 command: "sudo rm file".to_string(),
3791 cd: ".".to_string(),
3792 timeout_ms: None,
3793 },
3794 event_stream,
3795 cx,
3796 )
3797 });
3798
3799 // With always_allow_tool_actions=true, confirm patterns are overridden
3800 task.await
3801 .expect("command should be allowed with always_allow_tool_actions=true");
3802 }
3803
3804 // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
3805 {
3806 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3807 let environment = Rc::new(FakeThreadEnvironment {
3808 handle: handle.clone(),
3809 });
3810
3811 cx.update(|cx| {
3812 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3813 settings.always_allow_tool_actions = true;
3814 settings.tool_permissions.tools.insert(
3815 "terminal".into(),
3816 agent_settings::ToolRules {
3817 default_mode: settings::ToolPermissionMode::Deny,
3818 always_allow: vec![],
3819 always_deny: vec![],
3820 always_confirm: vec![],
3821 invalid_patterns: vec![],
3822 },
3823 );
3824 agent_settings::AgentSettings::override_global(settings, cx);
3825 });
3826
3827 #[allow(clippy::arc_with_non_send_sync)]
3828 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3829 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3830
3831 let task = cx.update(|cx| {
3832 tool.run(
3833 crate::TerminalToolInput {
3834 command: "echo hello".to_string(),
3835 cd: ".".to_string(),
3836 timeout_ms: None,
3837 },
3838 event_stream,
3839 cx,
3840 )
3841 });
3842
3843 // With always_allow_tool_actions=true, even default_mode: Deny is overridden
3844 task.await
3845 .expect("command should be allowed with always_allow_tool_actions=true");
3846 }
3847}
3848
3849#[gpui::test]
3850async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3851 init_test(cx);
3852
3853 cx.update(|cx| {
3854 cx.update_flags(true, vec!["subagents".to_string()]);
3855 });
3856
3857 let fs = FakeFs::new(cx.executor());
3858 fs.insert_tree(path!("/test"), json!({})).await;
3859 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3860 let project_context = cx.new(|_cx| ProjectContext::default());
3861 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3862 let context_server_registry =
3863 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3864 let model = Arc::new(FakeLanguageModel::default());
3865
3866 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3867 let environment = Rc::new(FakeThreadEnvironment { handle });
3868
3869 let thread = cx.new(|cx| {
3870 let mut thread = Thread::new(
3871 project.clone(),
3872 project_context,
3873 context_server_registry,
3874 Templates::new(),
3875 Some(model),
3876 cx,
3877 );
3878 thread.add_default_tools(environment, cx);
3879 thread
3880 });
3881
3882 thread.read_with(cx, |thread, _| {
3883 assert!(
3884 thread.has_registered_tool("subagent"),
3885 "subagent tool should be present when feature flag is enabled"
3886 );
3887 });
3888}
3889
3890#[gpui::test]
3891async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
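    // A subagent built with Thread::new_subagent and the parent's model should report
    // itself as a subagent at the given depth and have a model configured.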
3892 init_test(cx);
3893
3894 cx.update(|cx| {
3895 cx.update_flags(true, vec!["subagents".to_string()]);
3896 });
3897
3898 let fs = FakeFs::new(cx.executor());
3899 fs.insert_tree(path!("/test"), json!({})).await;
3900 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3901 let project_context = cx.new(|_cx| ProjectContext::default());
3902 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3903 let context_server_registry =
3904 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3905 let model = Arc::new(FakeLanguageModel::default());
3906
3907 let subagent_context = SubagentContext {
3908 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3909 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3910 depth: 1,
3911 summary_prompt: "Summarize".to_string(),
3912 context_low_prompt: "Context low".to_string(),
3913 };
3914
3915 let subagent = cx.new(|cx| {
3916 Thread::new_subagent(
3917 project.clone(),
3918 project_context,
3919 context_server_registry,
3920 Templates::new(),
3921 model.clone(),
3922 subagent_context,
3923 std::collections::BTreeMap::new(),
3924 cx,
3925 )
3926 });
3927
3928 subagent.read_with(cx, |thread, _| {
3929 assert!(thread.is_subagent());
3930 assert_eq!(thread.depth(), 1);
3931 assert!(thread.model().is_some());
3932 });
3933}
3934
3935#[gpui::test]
3936async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
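    // A subagent created at MAX_SUBAGENT_DEPTH should not get the subagent tool from
    // add_default_tools, so nesting stops at the depth limit.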
3937 init_test(cx);
3938
3939 cx.update(|cx| {
3940 cx.update_flags(true, vec!["subagents".to_string()]);
3941 });
3942
3943 let fs = FakeFs::new(cx.executor());
3944 fs.insert_tree(path!("/test"), json!({})).await;
3945 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3946 let project_context = cx.new(|_cx| ProjectContext::default());
3947 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3948 let context_server_registry =
3949 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3950 let model = Arc::new(FakeLanguageModel::default());
3951
3952 let subagent_context = SubagentContext {
3953 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3954 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3955 depth: MAX_SUBAGENT_DEPTH,
3956 summary_prompt: "Summarize".to_string(),
3957 context_low_prompt: "Context low".to_string(),
3958 };
3959
3960 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3961 let environment = Rc::new(FakeThreadEnvironment { handle });
3962
3963 let deep_subagent = cx.new(|cx| {
3964 let mut thread = Thread::new_subagent(
3965 project.clone(),
3966 project_context,
3967 context_server_registry,
3968 Templates::new(),
3969 model.clone(),
3970 subagent_context,
3971 std::collections::BTreeMap::new(),
3972 cx,
3973 );
3974 thread.add_default_tools(environment, cx);
3975 thread
3976 });
3977
3978 deep_subagent.read_with(cx, |thread, _| {
3979 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3980 assert!(
3981 !thread.has_registered_tool("subagent"),
3982 "subagent tool should not be present at max depth"
3983 );
3984 });
3985}
3986
3987#[gpui::test]
3988async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
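    // Submitting a task prompt to a subagent should produce exactly one user message
    // containing that prompt in the pending completion request.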
3989 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3990 let fake_model = model.as_fake();
3991
3992 cx.update(|cx| {
3993 cx.update_flags(true, vec!["subagents".to_string()]);
3994 });
3995
3996 let subagent_context = SubagentContext {
3997 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3998 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3999 depth: 1,
4000 summary_prompt: "Summarize your work".to_string(),
4001 context_low_prompt: "Context low, wrap up".to_string(),
4002 };
4003
4004 let project = thread.read_with(cx, |t, _| t.project.clone());
4005 let project_context = cx.new(|_cx| ProjectContext::default());
4006 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4007 let context_server_registry =
4008 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4009
4010 let subagent = cx.new(|cx| {
4011 Thread::new_subagent(
4012 project.clone(),
4013 project_context,
4014 context_server_registry,
4015 Templates::new(),
4016 model.clone(),
4017 subagent_context,
4018 std::collections::BTreeMap::new(),
4019 cx,
4020 )
4021 });
4022
4023 let task_prompt = "Find all TODO comments in the codebase";
4024 subagent
4025 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4026 .unwrap();
4027 cx.run_until_parked();
4028
4029 let pending = fake_model.pending_completions();
4030 assert_eq!(pending.len(), 1, "should have one pending completion");
4031
4032 let messages = &pending[0].messages;
4033 let user_messages: Vec<_> = messages
4034 .iter()
4035 .filter(|m| m.role == language_model::Role::User)
4036 .collect();
4037 assert_eq!(user_messages.len(), 1, "should have one user message");
4038
4039 let content = &user_messages[0].content[0];
4040 assert!(
4041 content.to_str().unwrap().contains("TODO"),
4042 "task prompt should be in user message"
4043 );
4044}
4045
4046#[gpui::test]
4047async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
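    // Once the subagent's turn ends, request_final_summary should send the configured
    // summary_prompt to the model as a new user message.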
4048 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4049 let fake_model = model.as_fake();
4050
4051 cx.update(|cx| {
4052 cx.update_flags(true, vec!["subagents".to_string()]);
4053 });
4054
4055 let subagent_context = SubagentContext {
4056 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4057 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4058 depth: 1,
4059 summary_prompt: "Please summarize what you found".to_string(),
4060 context_low_prompt: "Context low, wrap up".to_string(),
4061 };
4062
4063 let project = thread.read_with(cx, |t, _| t.project.clone());
4064 let project_context = cx.new(|_cx| ProjectContext::default());
4065 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4066 let context_server_registry =
4067 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4068
4069 let subagent = cx.new(|cx| {
4070 Thread::new_subagent(
4071 project.clone(),
4072 project_context,
4073 context_server_registry,
4074 Templates::new(),
4075 model.clone(),
4076 subagent_context,
4077 std::collections::BTreeMap::new(),
4078 cx,
4079 )
4080 });
4081
4082 subagent
4083 .update(cx, |thread, cx| {
4084 thread.submit_user_message("Do some work", cx)
4085 })
4086 .unwrap();
4087 cx.run_until_parked();
4088
4089 fake_model.send_last_completion_stream_text_chunk("I did the work");
4090 fake_model
4091 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4092 fake_model.end_last_completion_stream();
4093 cx.run_until_parked();
4094
4095 subagent
4096 .update(cx, |thread, cx| thread.request_final_summary(cx))
4097 .unwrap();
4098 cx.run_until_parked();
4099
4100 let pending = fake_model.pending_completions();
4101 assert!(
4102 !pending.is_empty(),
4103 "should have pending completion for summary"
4104 );
4105
4106 let messages = &pending.last().unwrap().messages;
4107 let user_messages: Vec<_> = messages
4108 .iter()
4109 .filter(|m| m.role == language_model::Role::User)
4110 .collect();
4111
4112 let last_user = user_messages.last().unwrap();
4113 assert!(
4114 last_user.content[0].to_str().unwrap().contains("summarize"),
4115 "summary prompt should be sent"
4116 );
4117}
4118
4119#[gpui::test]
4120async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
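    // restrict_tools should drop every registered tool that is not in the allowed set,
    // leaving only the tools named in that set.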
4121 init_test(cx);
4122
4123 cx.update(|cx| {
4124 cx.update_flags(true, vec!["subagents".to_string()]);
4125 });
4126
4127 let fs = FakeFs::new(cx.executor());
4128 fs.insert_tree(path!("/test"), json!({})).await;
4129 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4130 let project_context = cx.new(|_cx| ProjectContext::default());
4131 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4132 let context_server_registry =
4133 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4134 let model = Arc::new(FakeLanguageModel::default());
4135
4136 let subagent_context = SubagentContext {
4137 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4138 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4139 depth: 1,
4140 summary_prompt: "Summarize".to_string(),
4141 context_low_prompt: "Context low".to_string(),
4142 };
4143
4144 let subagent = cx.new(|cx| {
4145 let mut thread = Thread::new_subagent(
4146 project.clone(),
4147 project_context,
4148 context_server_registry,
4149 Templates::new(),
4150 model.clone(),
4151 subagent_context,
4152 std::collections::BTreeMap::new(),
4153 cx,
4154 );
4155 thread.add_tool(EchoTool);
4156 thread.add_tool(DelayTool);
4157 thread.add_tool(WordListTool);
4158 thread
4159 });
4160
4161 subagent.read_with(cx, |thread, _| {
4162 assert!(thread.has_registered_tool("echo"));
4163 assert!(thread.has_registered_tool("delay"));
4164 assert!(thread.has_registered_tool("word_list"));
4165 });
4166
4167 let allowed: collections::HashSet<gpui::SharedString> =
4168 vec!["echo".into()].into_iter().collect();
4169
4170 subagent.update(cx, |thread, _cx| {
4171 thread.restrict_tools(&allowed);
4172 });
4173
4174 subagent.read_with(cx, |thread, _| {
4175 assert!(
4176 thread.has_registered_tool("echo"),
4177 "echo should still be available"
4178 );
4179 assert!(
4180 !thread.has_registered_tool("delay"),
4181 "delay should be removed"
4182 );
4183 assert!(
4184 !thread.has_registered_tool("word_list"),
4185 "word_list should be removed"
4186 );
4187 });
4188}
4189
4190#[gpui::test]
4191async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
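    // Cancelling the parent thread should also cancel any subagent registered with
    // register_running_subagent.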
4192 init_test(cx);
4193
4194 cx.update(|cx| {
4195 cx.update_flags(true, vec!["subagents".to_string()]);
4196 });
4197
4198 let fs = FakeFs::new(cx.executor());
4199 fs.insert_tree(path!("/test"), json!({})).await;
4200 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4201 let project_context = cx.new(|_cx| ProjectContext::default());
4202 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4203 let context_server_registry =
4204 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4205 let model = Arc::new(FakeLanguageModel::default());
4206
4207 let parent = cx.new(|cx| {
4208 Thread::new(
4209 project.clone(),
4210 project_context.clone(),
4211 context_server_registry.clone(),
4212 Templates::new(),
4213 Some(model.clone()),
4214 cx,
4215 )
4216 });
4217
4218 let subagent_context = SubagentContext {
4219 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4220 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4221 depth: 1,
4222 summary_prompt: "Summarize".to_string(),
4223 context_low_prompt: "Context low".to_string(),
4224 };
4225
4226 let subagent = cx.new(|cx| {
4227 Thread::new_subagent(
4228 project.clone(),
4229 project_context.clone(),
4230 context_server_registry.clone(),
4231 Templates::new(),
4232 model.clone(),
4233 subagent_context,
4234 std::collections::BTreeMap::new(),
4235 cx,
4236 )
4237 });
4238
4239 parent.update(cx, |thread, _cx| {
4240 thread.register_running_subagent(subagent.downgrade());
4241 });
4242
4243 subagent
4244 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4245 .unwrap();
4246 cx.run_until_parked();
4247
4248 subagent.read_with(cx, |thread, _| {
4249 assert!(!thread.is_turn_complete(), "subagent should be running");
4250 });
4251
4252 parent.update(cx, |thread, cx| {
4253 thread.cancel(cx).detach();
4254 });
4255
4256 subagent.read_with(cx, |thread, _| {
4257 assert!(
4258 thread.is_turn_complete(),
4259 "subagent should be cancelled when parent cancels"
4260 );
4261 });
4262}
4263
4264#[gpui::test]
4265async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4266 // This test verifies that the subagent tool properly handles user cancellation
4267 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4268 init_test(cx);
4269 always_allow_tools(cx);
4270
4271 cx.update(|cx| {
4272 cx.update_flags(true, vec!["subagents".to_string()]);
4273 });
4274
4275 let fs = FakeFs::new(cx.executor());
4276 fs.insert_tree(path!("/test"), json!({})).await;
4277 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4278 let project_context = cx.new(|_cx| ProjectContext::default());
4279 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4280 let context_server_registry =
4281 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4282 let model = Arc::new(FakeLanguageModel::default());
4283
4284 let parent = cx.new(|cx| {
4285 Thread::new(
4286 project.clone(),
4287 project_context.clone(),
4288 context_server_registry.clone(),
4289 Templates::new(),
4290 Some(model.clone()),
4291 cx,
4292 )
4293 });
4294
4295 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4296 std::collections::BTreeMap::new();
4297
4298 #[allow(clippy::arc_with_non_send_sync)]
4299 let tool = Arc::new(SubagentTool::new(
4300 parent.downgrade(),
4301 project.clone(),
4302 project_context,
4303 context_server_registry,
4304 Templates::new(),
4305 0,
4306 parent_tools,
4307 ));
4308
4309 let (event_stream, _rx, mut cancellation_tx) =
4310 crate::ToolCallEventStream::test_with_cancellation();
4311
4312 // Start the subagent tool
4313 let task = cx.update(|cx| {
4314 tool.run(
4315 SubagentToolInput {
4316 subagents: vec![crate::SubagentConfig {
4317 label: "Long running task".to_string(),
4318 task_prompt: "Do a very long task that takes forever".to_string(),
4319 summary_prompt: "Summarize".to_string(),
4320 context_low_prompt: "Context low".to_string(),
4321 timeout_ms: None,
4322 allowed_tools: None,
4323 }],
4324 },
4325 event_stream.clone(),
4326 cx,
4327 )
4328 });
4329
4330 cx.run_until_parked();
4331
4332 // Signal cancellation via the event stream
4333 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4334
4335 // The task should complete promptly with a cancellation error
4336 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4337 let result = futures::select! {
4338 result = task.fuse() => result,
4339 _ = timeout.fuse() => {
4340 panic!("subagent tool did not respond to cancellation within timeout");
4341 }
4342 };
4343
4344 // Verify we got a cancellation error
4345 let err = result.unwrap_err();
4346 assert!(
4347 err.to_string().contains("cancelled by user"),
4348 "expected cancellation error, got: {}",
4349 err
4350 );
4351}
4352
4353#[gpui::test]
4354async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
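    // A non-retryable model error (here, a missing API key) should end the subagent's
    // turn instead of leaving it running.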
4355 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4356 let fake_model = model.as_fake();
4357
4358 cx.update(|cx| {
4359 cx.update_flags(true, vec!["subagents".to_string()]);
4360 });
4361
4362 let subagent_context = SubagentContext {
4363 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4364 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4365 depth: 1,
4366 summary_prompt: "Summarize".to_string(),
4367 context_low_prompt: "Context low".to_string(),
4368 };
4369
4370 let project = thread.read_with(cx, |t, _| t.project.clone());
4371 let project_context = cx.new(|_cx| ProjectContext::default());
4372 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4373 let context_server_registry =
4374 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4375
4376 let subagent = cx.new(|cx| {
4377 Thread::new_subagent(
4378 project.clone(),
4379 project_context,
4380 context_server_registry,
4381 Templates::new(),
4382 model.clone(),
4383 subagent_context,
4384 std::collections::BTreeMap::new(),
4385 cx,
4386 )
4387 });
4388
4389 subagent
4390 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4391 .unwrap();
4392 cx.run_until_parked();
4393
4394 subagent.read_with(cx, |thread, _| {
4395 assert!(!thread.is_turn_complete(), "turn should be in progress");
4396 });
4397
4398 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4399 provider: LanguageModelProviderName::from("Fake".to_string()),
4400 });
4401 fake_model.end_last_completion_stream();
4402 cx.run_until_parked();
4403
4404 subagent.read_with(cx, |thread, _| {
4405 assert!(
4406 thread.is_turn_complete(),
4407 "turn should be complete after non-retryable error"
4408 );
4409 });
4410}
4411
4412#[gpui::test]
4413async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
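    // Interrupting the subagent via interrupt_for_summary (as the timeout path would)
    // should send the context_low_prompt so the model wraps up early.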
4414 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4415 let fake_model = model.as_fake();
4416
4417 cx.update(|cx| {
4418 cx.update_flags(true, vec!["subagents".to_string()]);
4419 });
4420
4421 let subagent_context = SubagentContext {
4422 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4423 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4424 depth: 1,
4425 summary_prompt: "Summarize your work".to_string(),
4426 context_low_prompt: "Context low, stop and summarize".to_string(),
4427 };
4428
4429 let project = thread.read_with(cx, |t, _| t.project.clone());
4430 let project_context = cx.new(|_cx| ProjectContext::default());
4431 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4432 let context_server_registry =
4433 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4434
4435 let subagent = cx.new(|cx| {
4436 Thread::new_subagent(
4437 project.clone(),
4438 project_context.clone(),
4439 context_server_registry.clone(),
4440 Templates::new(),
4441 model.clone(),
4442 subagent_context.clone(),
4443 std::collections::BTreeMap::new(),
4444 cx,
4445 )
4446 });
4447
4448 subagent.update(cx, |thread, _| {
4449 thread.add_tool(EchoTool);
4450 });
4451
4452 subagent
4453 .update(cx, |thread, cx| {
4454 thread.submit_user_message("Do some work", cx)
4455 })
4456 .unwrap();
4457 cx.run_until_parked();
4458
4459 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4460 fake_model
4461 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4462 fake_model.end_last_completion_stream();
4463 cx.run_until_parked();
4464
4465 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4466 assert!(
4467 interrupt_result.is_ok(),
4468 "interrupt_for_summary should succeed"
4469 );
4470
4471 cx.run_until_parked();
4472
4473 let pending = fake_model.pending_completions();
4474 assert!(
4475 !pending.is_empty(),
4476 "should have pending completion for interrupted summary"
4477 );
4478
4479 let messages = &pending.last().unwrap().messages;
4480 let user_messages: Vec<_> = messages
4481 .iter()
4482 .filter(|m| m.role == language_model::Role::User)
4483 .collect();
4484
4485 let last_user = user_messages.last().unwrap();
4486 let content_str = last_user.content[0].to_str().unwrap();
4487 assert!(
4488 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4489 "context_low_prompt should be sent when interrupting: got {:?}",
4490 content_str
4491 );
4492}
4493
4494#[gpui::test]
4495async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4496 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4497 let fake_model = model.as_fake();
4498
4499 cx.update(|cx| {
4500 cx.update_flags(true, vec!["subagents".to_string()]);
4501 });
4502
4503 let subagent_context = SubagentContext {
4504 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4505 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4506 depth: 1,
4507 summary_prompt: "Summarize".to_string(),
4508 context_low_prompt: "Context low".to_string(),
4509 };
4510
4511 let project = thread.read_with(cx, |t, _| t.project.clone());
4512 let project_context = cx.new(|_cx| ProjectContext::default());
4513 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4514 let context_server_registry =
4515 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4516
4517 let subagent = cx.new(|cx| {
4518 Thread::new_subagent(
4519 project.clone(),
4520 project_context,
4521 context_server_registry,
4522 Templates::new(),
4523 model.clone(),
4524 subagent_context,
4525 std::collections::BTreeMap::new(),
4526 cx,
4527 )
4528 });
4529
4530 subagent
4531 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4532 .unwrap();
4533 cx.run_until_parked();
4534
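    // Report usage at 80% of the model's maximum so the remaining ratio lands around
    // 20%, below the 25% low-context threshold asserted at the end of the test.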
4535 let max_tokens = model.max_token_count();
4536 let high_usage = language_model::TokenUsage {
4537 input_tokens: (max_tokens as f64 * 0.80) as u64,
4538 output_tokens: 0,
4539 cache_creation_input_tokens: 0,
4540 cache_read_input_tokens: 0,
4541 };
4542
4543 fake_model
4544 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4545 fake_model.send_last_completion_stream_text_chunk("Working...");
4546 fake_model
4547 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4548 fake_model.end_last_completion_stream();
4549 cx.run_until_parked();
4550
4551 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4552 assert!(usage.is_some(), "should have token usage after completion");
4553
4554 let usage = usage.unwrap();
4555 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4556 assert!(
4557 remaining_ratio <= 0.25,
4558 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4559 remaining_ratio * 100.0
4560 );
4561}
4562
4563#[gpui::test]
4564async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
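    // validate_subagents should reject a configuration whose allowed_tools names a tool
    // the parent thread does not provide, and the error should name the offender.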
4565 init_test(cx);
4566
4567 cx.update(|cx| {
4568 cx.update_flags(true, vec!["subagents".to_string()]);
4569 });
4570
4571 let fs = FakeFs::new(cx.executor());
4572 fs.insert_tree(path!("/test"), json!({})).await;
4573 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4574 let project_context = cx.new(|_cx| ProjectContext::default());
4575 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4576 let context_server_registry =
4577 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4578 let model = Arc::new(FakeLanguageModel::default());
4579
4580 let parent = cx.new(|cx| {
4581 let mut thread = Thread::new(
4582 project.clone(),
4583 project_context.clone(),
4584 context_server_registry.clone(),
4585 Templates::new(),
4586 Some(model.clone()),
4587 cx,
4588 );
4589 thread.add_tool(EchoTool);
4590 thread
4591 });
4592
4593 let mut parent_tools: std::collections::BTreeMap<
4594 gpui::SharedString,
4595 Arc<dyn crate::AnyAgentTool>,
4596 > = std::collections::BTreeMap::new();
4597 parent_tools.insert("echo".into(), EchoTool.erase());
4598
4599 #[allow(clippy::arc_with_non_send_sync)]
4600 let tool = Arc::new(SubagentTool::new(
4601 parent.downgrade(),
4602 project,
4603 project_context,
4604 context_server_registry,
4605 Templates::new(),
4606 0,
4607 parent_tools,
4608 ));
4609
4610 let subagent_configs = vec![crate::SubagentConfig {
4611 label: "Test".to_string(),
4612 task_prompt: "Do something".to_string(),
4613 summary_prompt: "Summarize".to_string(),
4614 context_low_prompt: "Context low".to_string(),
4615 timeout_ms: None,
4616 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4617 }];
4618 let result = tool.validate_subagents(&subagent_configs);
4619 assert!(result.is_err(), "should reject unknown tool");
4620 let err_msg = result.unwrap_err().to_string();
4621 assert!(
4622 err_msg.contains("nonexistent_tool"),
4623 "error should mention the invalid tool name: {}",
4624 err_msg
4625 );
4626 assert!(
4627 err_msg.contains("do not exist"),
4628 "error should explain the tool does not exist: {}",
4629 err_msg
4630 );
4631}
4632
4633#[gpui::test]
4634async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
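    // An EndTurn with no content from the model should still complete the subagent's
    // turn rather than leaving it stuck.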
4635 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4636 let fake_model = model.as_fake();
4637
4638 cx.update(|cx| {
4639 cx.update_flags(true, vec!["subagents".to_string()]);
4640 });
4641
4642 let subagent_context = SubagentContext {
4643 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4644 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4645 depth: 1,
4646 summary_prompt: "Summarize".to_string(),
4647 context_low_prompt: "Context low".to_string(),
4648 };
4649
4650 let project = thread.read_with(cx, |t, _| t.project.clone());
4651 let project_context = cx.new(|_cx| ProjectContext::default());
4652 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4653 let context_server_registry =
4654 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4655
4656 let subagent = cx.new(|cx| {
4657 Thread::new_subagent(
4658 project.clone(),
4659 project_context,
4660 context_server_registry,
4661 Templates::new(),
4662 model.clone(),
4663 subagent_context,
4664 std::collections::BTreeMap::new(),
4665 cx,
4666 )
4667 });
4668
4669 subagent
4670 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4671 .unwrap();
4672 cx.run_until_parked();
4673
4674 fake_model
4675 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4676 fake_model.end_last_completion_stream();
4677 cx.run_until_parked();
4678
4679 subagent.read_with(cx, |thread, _| {
4680 assert!(
4681 thread.is_turn_complete(),
4682 "turn should complete even with empty response"
4683 );
4684 });
4685}
4686
4687#[gpui::test]
4688async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
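    // Nested subagents: a depth-2 thread constructed beneath a depth-1 subagent should
    // still be able to submit messages to the model.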
4689 init_test(cx);
4690
4691 cx.update(|cx| {
4692 cx.update_flags(true, vec!["subagents".to_string()]);
4693 });
4694
4695 let fs = FakeFs::new(cx.executor());
4696 fs.insert_tree(path!("/test"), json!({})).await;
4697 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4698 let project_context = cx.new(|_cx| ProjectContext::default());
4699 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4700 let context_server_registry =
4701 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4702 let model = Arc::new(FakeLanguageModel::default());
4703
4704 let depth_1_context = SubagentContext {
4705 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4706 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4707 depth: 1,
4708 summary_prompt: "Summarize".to_string(),
4709 context_low_prompt: "Context low".to_string(),
4710 };
4711
4712 let depth_1_subagent = cx.new(|cx| {
4713 Thread::new_subagent(
4714 project.clone(),
4715 project_context.clone(),
4716 context_server_registry.clone(),
4717 Templates::new(),
4718 model.clone(),
4719 depth_1_context,
4720 std::collections::BTreeMap::new(),
4721 cx,
4722 )
4723 });
4724
4725 depth_1_subagent.read_with(cx, |thread, _| {
4726 assert_eq!(thread.depth(), 1);
4727 assert!(thread.is_subagent());
4728 });
4729
4730 let depth_2_context = SubagentContext {
4731 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4732 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4733 depth: 2,
4734 summary_prompt: "Summarize depth 2".to_string(),
4735 context_low_prompt: "Context low depth 2".to_string(),
4736 };
4737
4738 let depth_2_subagent = cx.new(|cx| {
4739 Thread::new_subagent(
4740 project.clone(),
4741 project_context.clone(),
4742 context_server_registry.clone(),
4743 Templates::new(),
4744 model.clone(),
4745 depth_2_context,
4746 std::collections::BTreeMap::new(),
4747 cx,
4748 )
4749 });
4750
4751 depth_2_subagent.read_with(cx, |thread, _| {
4752 assert_eq!(thread.depth(), 2);
4753 assert!(thread.is_subagent());
4754 });
4755
4756 depth_2_subagent
4757 .update(cx, |thread, cx| {
4758 thread.submit_user_message("Nested task", cx)
4759 })
4760 .unwrap();
4761 cx.run_until_parked();
4762
4763 let pending = model.as_fake().pending_completions();
4764 assert!(
4765 !pending.is_empty(),
4766 "depth-2 subagent should be able to submit messages"
4767 );
4768}
4769
4770#[gpui::test]
4771async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4772 init_test(cx);
4773 always_allow_tools(cx);
4774
4775 cx.update(|cx| {
4776 cx.update_flags(true, vec!["subagents".to_string()]);
4777 });
4778
4779 let fs = FakeFs::new(cx.executor());
4780 fs.insert_tree(path!("/test"), json!({})).await;
4781 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4782 let project_context = cx.new(|_cx| ProjectContext::default());
4783 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4784 let context_server_registry =
4785 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4786 let model = Arc::new(FakeLanguageModel::default());
4787 let fake_model = model.as_fake();
4788
4789 let subagent_context = SubagentContext {
4790 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4791 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4792 depth: 1,
4793 summary_prompt: "Summarize what you did".to_string(),
4794 context_low_prompt: "Context low".to_string(),
4795 };
4796
4797 let subagent = cx.new(|cx| {
4798 let mut thread = Thread::new_subagent(
4799 project.clone(),
4800 project_context,
4801 context_server_registry,
4802 Templates::new(),
4803 model.clone(),
4804 subagent_context,
4805 std::collections::BTreeMap::new(),
4806 cx,
4807 );
4808 thread.add_tool(EchoTool);
4809 thread
4810 });
4811
4812 subagent.read_with(cx, |thread, _| {
4813 assert!(
4814 thread.has_registered_tool("echo"),
4815 "subagent should have echo tool"
4816 );
4817 });
4818
4819 subagent
4820 .update(cx, |thread, cx| {
4821 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4822 })
4823 .unwrap();
4824 cx.run_until_parked();
4825
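    // Have the fake model call the echo tool; ending the stream lets the thread run the
    // tool and send its result back to the model in the next request.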
4826 let tool_use = LanguageModelToolUse {
4827 id: "tool_call_1".into(),
4828 name: EchoTool::name().into(),
4829 raw_input: json!({"text": "hello world"}).to_string(),
4830 input: json!({"text": "hello world"}),
4831 is_input_complete: true,
4832 thought_signature: None,
4833 };
4834 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4835 fake_model.end_last_completion_stream();
4836 cx.run_until_parked();
4837
4838 let pending = fake_model.pending_completions();
4839 assert!(
4840 !pending.is_empty(),
4841 "should have pending completion after tool use"
4842 );
4843
4844 let last_completion = pending.last().unwrap();
4845 let has_tool_result = last_completion.messages.iter().any(|m| {
4846 m.content
4847 .iter()
4848 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4849 });
4850 assert!(
4851 has_tool_result,
4852 "tool result should be in the messages sent back to the model"
4853 );
4854}
4855
4856#[gpui::test]
4857async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
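    // With MAX_PARALLEL_SUBAGENTS already registered on the parent, running the subagent
    // tool again should fail rather than spawn another subagent.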
4858 init_test(cx);
4859
4860 cx.update(|cx| {
4861 cx.update_flags(true, vec!["subagents".to_string()]);
4862 });
4863
4864 let fs = FakeFs::new(cx.executor());
4865 fs.insert_tree(path!("/test"), json!({})).await;
4866 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4867 let project_context = cx.new(|_cx| ProjectContext::default());
4868 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4869 let context_server_registry =
4870 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4871 let model = Arc::new(FakeLanguageModel::default());
4872
4873 let parent = cx.new(|cx| {
4874 Thread::new(
4875 project.clone(),
4876 project_context.clone(),
4877 context_server_registry.clone(),
4878 Templates::new(),
4879 Some(model.clone()),
4880 cx,
4881 )
4882 });
4883
4884 let mut subagents = Vec::new();
4885 for i in 0..MAX_PARALLEL_SUBAGENTS {
4886 let subagent_context = SubagentContext {
4887 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4888 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4889 depth: 1,
4890 summary_prompt: "Summarize".to_string(),
4891 context_low_prompt: "Context low".to_string(),
4892 };
4893
4894 let subagent = cx.new(|cx| {
4895 Thread::new_subagent(
4896 project.clone(),
4897 project_context.clone(),
4898 context_server_registry.clone(),
4899 Templates::new(),
4900 model.clone(),
4901 subagent_context,
4902 std::collections::BTreeMap::new(),
4903 cx,
4904 )
4905 });
4906
4907 parent.update(cx, |thread, _cx| {
4908 thread.register_running_subagent(subagent.downgrade());
4909 });
4910 subagents.push(subagent);
4911 }
4912
4913 parent.read_with(cx, |thread, _| {
4914 assert_eq!(
4915 thread.running_subagent_count(),
4916 MAX_PARALLEL_SUBAGENTS,
4917 "should have MAX_PARALLEL_SUBAGENTS registered"
4918 );
4919 });
4920
4921 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4922 std::collections::BTreeMap::new();
4923
4924 #[allow(clippy::arc_with_non_send_sync)]
4925 let tool = Arc::new(SubagentTool::new(
4926 parent.downgrade(),
4927 project.clone(),
4928 project_context,
4929 context_server_registry,
4930 Templates::new(),
4931 0,
4932 parent_tools,
4933 ));
4934
4935 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4936
4937 let result = cx.update(|cx| {
4938 tool.run(
4939 SubagentToolInput {
4940 subagents: vec![crate::SubagentConfig {
4941 label: "Test".to_string(),
4942 task_prompt: "Do something".to_string(),
4943 summary_prompt: "Summarize".to_string(),
4944 context_low_prompt: "Context low".to_string(),
4945 timeout_ms: None,
4946 allowed_tools: None,
4947 }],
4948 },
4949 event_stream,
4950 cx,
4951 )
4952 });
4953
4954 let err = result.await.unwrap_err();
4955 assert!(
4956 err.to_string().contains("Maximum parallel subagents"),
4957 "should reject when max parallel subagents reached: {}",
4958 err
4959 );
4960
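    // Keep the strong subagent handles alive until after the tool call; the parent only
    // holds weak references to its running subagents.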
4961 drop(subagents);
4962}
4963
4964#[gpui::test]
4965async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
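    // Full round trip through the subagent tool: the task prompt reaches the subagent,
    // its turn completes, the summary prompt is sent, and the summary text comes back
    // as the tool's result.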
4966 init_test(cx);
4967 always_allow_tools(cx);
4968
4969 cx.update(|cx| {
4970 cx.update_flags(true, vec!["subagents".to_string()]);
4971 });
4972
4973 let fs = FakeFs::new(cx.executor());
4974 fs.insert_tree(path!("/test"), json!({})).await;
4975 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4976 let project_context = cx.new(|_cx| ProjectContext::default());
4977 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4978 let context_server_registry =
4979 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4980 let model = Arc::new(FakeLanguageModel::default());
4981 let fake_model = model.as_fake();
4982
4983 let parent = cx.new(|cx| {
4984 let mut thread = Thread::new(
4985 project.clone(),
4986 project_context.clone(),
4987 context_server_registry.clone(),
4988 Templates::new(),
4989 Some(model.clone()),
4990 cx,
4991 );
4992 thread.add_tool(EchoTool);
4993 thread
4994 });
4995
4996 let mut parent_tools: std::collections::BTreeMap<
4997 gpui::SharedString,
4998 Arc<dyn crate::AnyAgentTool>,
4999 > = std::collections::BTreeMap::new();
5000 parent_tools.insert("echo".into(), EchoTool.erase());
5001
5002 #[allow(clippy::arc_with_non_send_sync)]
5003 let tool = Arc::new(SubagentTool::new(
5004 parent.downgrade(),
5005 project.clone(),
5006 project_context,
5007 context_server_registry,
5008 Templates::new(),
5009 0,
5010 parent_tools,
5011 ));
5012
5013 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5014
5015 let task = cx.update(|cx| {
5016 tool.run(
5017 SubagentToolInput {
5018 subagents: vec![crate::SubagentConfig {
5019 label: "Research task".to_string(),
5020 task_prompt: "Find all TODOs in the codebase".to_string(),
5021 summary_prompt: "Summarize what you found".to_string(),
5022 context_low_prompt: "Context low, wrap up".to_string(),
5023 timeout_ms: None,
5024 allowed_tools: None,
5025 }],
5026 },
5027 event_stream,
5028 cx,
5029 )
5030 });
5031
5032 cx.run_until_parked();
5033
5034 let pending = fake_model.pending_completions();
5035 assert!(
5036 !pending.is_empty(),
5037 "subagent should have started and sent a completion request"
5038 );
5039
5040 let first_completion = &pending[0];
5041 let has_task_prompt = first_completion.messages.iter().any(|m| {
5042 m.role == language_model::Role::User
5043 && m.content
5044 .iter()
5045 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5046 });
5047 assert!(has_task_prompt, "task prompt should be sent to subagent");
5048
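    // First turn: the subagent reports its findings and ends the turn, which should
    // trigger the summary request.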
5049 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5050 fake_model
5051 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5052 fake_model.end_last_completion_stream();
5053 cx.run_until_parked();
5054
5055 let pending = fake_model.pending_completions();
5056 assert!(
5057 !pending.is_empty(),
5058 "should have pending completion for summary request"
5059 );
5060
5061 let last_completion = pending.last().unwrap();
5062 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5063 m.role == language_model::Role::User
5064 && m.content.iter().any(|c| {
5065 c.to_str()
5066 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5067 .unwrap_or(false)
5068 })
5069 });
5070 assert!(
5071 has_summary_prompt,
5072 "summary prompt should be sent after task completion"
5073 );
5074
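    // Second turn: answer the summary prompt; this text becomes the tool's result.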
5075 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5076 fake_model
5077 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5078 fake_model.end_last_completion_stream();
5079 cx.run_until_parked();
5080
5081 let result = task.await;
5082 assert!(result.is_ok(), "subagent tool should complete successfully");
5083
5084 let summary = result.unwrap();
5085 assert!(
5086 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5087 "summary should contain subagent's response: {}",
5088 summary
5089 );
5090}
5091
5092#[gpui::test]
5093async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
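    // An edit_file deny rule matching the target path should make the tool call fail
    // with a "blocked" error.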
5094 init_test(cx);
5095
5096 let fs = FakeFs::new(cx.executor());
5097 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5098 .await;
5099 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5100
5101 cx.update(|cx| {
5102 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5103 settings.tool_permissions.tools.insert(
5104 "edit_file".into(),
5105 agent_settings::ToolRules {
5106 default_mode: settings::ToolPermissionMode::Allow,
5107 always_allow: vec![],
5108 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5109 always_confirm: vec![],
5110 invalid_patterns: vec![],
5111 },
5112 );
5113 agent_settings::AgentSettings::override_global(settings, cx);
5114 });
5115
5116 let context_server_registry =
5117 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5118 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5119 let templates = crate::Templates::new();
5120 let thread = cx.new(|cx| {
5121 crate::Thread::new(
5122 project.clone(),
5123 cx.new(|_cx| prompt_store::ProjectContext::default()),
5124 context_server_registry,
5125 templates.clone(),
5126 None,
5127 cx,
5128 )
5129 });
5130
5131 #[allow(clippy::arc_with_non_send_sync)]
5132 let tool = Arc::new(crate::EditFileTool::new(
5133 project.clone(),
5134 thread.downgrade(),
5135 language_registry,
5136 templates,
5137 ));
5138 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5139
5140 let task = cx.update(|cx| {
5141 tool.run(
5142 crate::EditFileToolInput {
5143 display_description: "Edit sensitive file".to_string(),
5144 path: "root/sensitive_config.txt".into(),
5145 mode: crate::EditFileMode::Edit,
5146 },
5147 event_stream,
5148 cx,
5149 )
5150 });
5151
5152 let result = task.await;
5153 assert!(result.is_err(), "expected edit to be blocked");
5154 assert!(
5155 result.unwrap_err().to_string().contains("blocked"),
5156 "error should mention the edit was blocked"
5157 );
5158}
5159
5160#[gpui::test]
5161async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5162 init_test(cx);
5163
5164 let fs = FakeFs::new(cx.executor());
5165 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5166 .await;
5167 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5168
5169 cx.update(|cx| {
5170 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5171 settings.tool_permissions.tools.insert(
5172 "delete_path".into(),
5173 agent_settings::ToolRules {
5174 default_mode: settings::ToolPermissionMode::Allow,
5175 always_allow: vec![],
5176 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5177 always_confirm: vec![],
5178 invalid_patterns: vec![],
5179 },
5180 );
5181 agent_settings::AgentSettings::override_global(settings, cx);
5182 });
5183
5184 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5185
5186 #[allow(clippy::arc_with_non_send_sync)]
5187 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5188 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5189
5190 let task = cx.update(|cx| {
5191 tool.run(
5192 crate::DeletePathToolInput {
5193 path: "root/important_data.txt".to_string(),
5194 },
5195 event_stream,
5196 cx,
5197 )
5198 });
5199
5200 let result = task.await;
5201 assert!(result.is_err(), "expected deletion to be blocked");
5202 assert!(
5203 result.unwrap_err().to_string().contains("blocked"),
5204 "error should mention the deletion was blocked"
5205 );
5206}
5207
5208#[gpui::test]
5209async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
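    // move_path checks both ends of the move: a deny rule matching only the destination
    // path should still block the operation.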
5210 init_test(cx);
5211
5212 let fs = FakeFs::new(cx.executor());
5213 fs.insert_tree(
5214 "/root",
5215 json!({
5216 "safe.txt": "content",
5217 "protected": {}
5218 }),
5219 )
5220 .await;
5221 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5222
5223 cx.update(|cx| {
5224 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5225 settings.tool_permissions.tools.insert(
5226 "move_path".into(),
5227 agent_settings::ToolRules {
5228 default_mode: settings::ToolPermissionMode::Allow,
5229 always_allow: vec![],
5230 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5231 always_confirm: vec![],
5232 invalid_patterns: vec![],
5233 },
5234 );
5235 agent_settings::AgentSettings::override_global(settings, cx);
5236 });
5237
5238 #[allow(clippy::arc_with_non_send_sync)]
5239 let tool = Arc::new(crate::MovePathTool::new(project));
5240 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5241
5242 let task = cx.update(|cx| {
5243 tool.run(
5244 crate::MovePathToolInput {
5245 source_path: "root/safe.txt".to_string(),
5246 destination_path: "root/protected/safe.txt".to_string(),
5247 },
5248 event_stream,
5249 cx,
5250 )
5251 });
5252
5253 let result = task.await;
5254 assert!(
5255 result.is_err(),
5256 "expected move to be blocked due to destination path"
5257 );
5258 assert!(
5259 result.unwrap_err().to_string().contains("blocked"),
5260 "error should mention the move was blocked"
5261 );
5262}
5263
5264#[gpui::test]
5265async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5266 init_test(cx);
5267
5268 let fs = FakeFs::new(cx.executor());
5269 fs.insert_tree(
5270 "/root",
5271 json!({
5272 "secret.txt": "secret content",
5273 "public": {}
5274 }),
5275 )
5276 .await;
5277 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5278
5279 cx.update(|cx| {
5280 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5281 settings.tool_permissions.tools.insert(
5282 "move_path".into(),
5283 agent_settings::ToolRules {
5284 default_mode: settings::ToolPermissionMode::Allow,
5285 always_allow: vec![],
5286 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5287 always_confirm: vec![],
5288 invalid_patterns: vec![],
5289 },
5290 );
5291 agent_settings::AgentSettings::override_global(settings, cx);
5292 });
5293
5294 #[allow(clippy::arc_with_non_send_sync)]
5295 let tool = Arc::new(crate::MovePathTool::new(project));
5296 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5297
5298 let task = cx.update(|cx| {
5299 tool.run(
5300 crate::MovePathToolInput {
5301 source_path: "root/secret.txt".to_string(),
5302 destination_path: "root/public/not_secret.txt".to_string(),
5303 },
5304 event_stream,
5305 cx,
5306 )
5307 });
5308
5309 let result = task.await;
5310 assert!(
5311 result.is_err(),
5312 "expected move to be blocked due to source path"
5313 );
5314 assert!(
5315 result.unwrap_err().to_string().contains("blocked"),
5316 "error should mention the move was blocked"
5317 );
5318}
5319
5320#[gpui::test]
5321async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5322 init_test(cx);
5323
5324 let fs = FakeFs::new(cx.executor());
5325 fs.insert_tree(
5326 "/root",
5327 json!({
5328 "confidential.txt": "confidential data",
5329 "dest": {}
5330 }),
5331 )
5332 .await;
5333 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5334
5335 cx.update(|cx| {
5336 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5337 settings.tool_permissions.tools.insert(
5338 "copy_path".into(),
5339 agent_settings::ToolRules {
5340 default_mode: settings::ToolPermissionMode::Allow,
5341 always_allow: vec![],
5342 always_deny: vec![
5343 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5344 ],
5345 always_confirm: vec![],
5346 invalid_patterns: vec![],
5347 },
5348 );
5349 agent_settings::AgentSettings::override_global(settings, cx);
5350 });
5351
5352 #[allow(clippy::arc_with_non_send_sync)]
5353 let tool = Arc::new(crate::CopyPathTool::new(project));
5354 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5355
5356 let task = cx.update(|cx| {
5357 tool.run(
5358 crate::CopyPathToolInput {
5359 source_path: "root/confidential.txt".to_string(),
5360 destination_path: "root/dest/copy.txt".to_string(),
5361 },
5362 event_stream,
5363 cx,
5364 )
5365 });
5366
5367 let result = task.await;
5368 assert!(result.is_err(), "expected copy to be blocked");
5369 assert!(
5370 result.unwrap_err().to_string().contains("blocked"),
5371 "error should mention the copy was blocked"
5372 );
5373}
5374
5375#[gpui::test]
5376async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
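    // save_file accepts multiple paths; a deny rule matching any one of them should
    // block the whole call.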
5377 init_test(cx);
5378
5379 let fs = FakeFs::new(cx.executor());
5380 fs.insert_tree(
5381 "/root",
5382 json!({
5383 "normal.txt": "normal content",
5384 "readonly": {
5385 "config.txt": "readonly content"
5386 }
5387 }),
5388 )
5389 .await;
5390 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5391
5392 cx.update(|cx| {
5393 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5394 settings.tool_permissions.tools.insert(
5395 "save_file".into(),
5396 agent_settings::ToolRules {
5397 default_mode: settings::ToolPermissionMode::Allow,
5398 always_allow: vec![],
5399 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5400 always_confirm: vec![],
5401 invalid_patterns: vec![],
5402 },
5403 );
5404 agent_settings::AgentSettings::override_global(settings, cx);
5405 });
5406
5407 #[allow(clippy::arc_with_non_send_sync)]
5408 let tool = Arc::new(crate::SaveFileTool::new(project));
5409 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5410
5411 let task = cx.update(|cx| {
5412 tool.run(
5413 crate::SaveFileToolInput {
5414 paths: vec![
5415 std::path::PathBuf::from("root/normal.txt"),
5416 std::path::PathBuf::from("root/readonly/config.txt"),
5417 ],
5418 },
5419 event_stream,
5420 cx,
5421 )
5422 });
5423
5424 let result = task.await;
5425 assert!(
5426 result.is_err(),
5427 "expected save to be blocked due to denied path"
5428 );
5429 assert!(
5430 result.unwrap_err().to_string().contains("blocked"),
5431 "error should mention the save was blocked"
5432 );
5433}
5434
5435#[gpui::test]
5436async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5437 init_test(cx);
5438
5439 let fs = FakeFs::new(cx.executor());
5440 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5441 .await;
5442 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5443
5444 cx.update(|cx| {
5445 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5446 settings.always_allow_tool_actions = false;
5447 settings.tool_permissions.tools.insert(
5448 "save_file".into(),
5449 agent_settings::ToolRules {
5450 default_mode: settings::ToolPermissionMode::Allow,
5451 always_allow: vec![],
5452 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5453 always_confirm: vec![],
5454 invalid_patterns: vec![],
5455 },
5456 );
5457 agent_settings::AgentSettings::override_global(settings, cx);
5458 });
5459
5460 #[allow(clippy::arc_with_non_send_sync)]
5461 let tool = Arc::new(crate::SaveFileTool::new(project));
5462 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5463
5464 let task = cx.update(|cx| {
5465 tool.run(
5466 crate::SaveFileToolInput {
5467 paths: vec![std::path::PathBuf::from("root/config.secret")],
5468 },
5469 event_stream,
5470 cx,
5471 )
5472 });
5473
5474 let result = task.await;
5475 assert!(result.is_err(), "expected save to be blocked");
5476 assert!(
5477 result.unwrap_err().to_string().contains("blocked"),
5478 "error should mention the save was blocked"
5479 );
5480}
5481
5482#[gpui::test]
5483async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5484 init_test(cx);
5485
5486 cx.update(|cx| {
5487 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5488 settings.tool_permissions.tools.insert(
5489 "web_search".into(),
5490 agent_settings::ToolRules {
5491 default_mode: settings::ToolPermissionMode::Allow,
5492 always_allow: vec![],
5493 always_deny: vec![
5494 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5495 ],
5496 always_confirm: vec![],
5497 invalid_patterns: vec![],
5498 },
5499 );
5500 agent_settings::AgentSettings::override_global(settings, cx);
5501 });
5502
5503 #[allow(clippy::arc_with_non_send_sync)]
5504 let tool = Arc::new(crate::WebSearchTool);
5505 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5506
5507 let input: crate::WebSearchToolInput =
5508 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5509
5510 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5511
5512 let result = task.await;
5513 assert!(result.is_err(), "expected search to be blocked");
5514 assert!(
5515 result.unwrap_err().to_string().contains("blocked"),
5516 "error should mention the search was blocked"
5517 );
5518}
5519
5520#[gpui::test]
5521async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
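    // With default_mode: Confirm, an always_allow pattern matching the path should skip
    // the confirmation step: no ToolCallAuthorization event is emitted.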
5522 init_test(cx);
5523
5524 let fs = FakeFs::new(cx.executor());
5525 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5526 .await;
5527 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5528
5529 cx.update(|cx| {
5530 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5531 settings.always_allow_tool_actions = false;
5532 settings.tool_permissions.tools.insert(
5533 "edit_file".into(),
5534 agent_settings::ToolRules {
5535 default_mode: settings::ToolPermissionMode::Confirm,
5536 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5537 always_deny: vec![],
5538 always_confirm: vec![],
5539 invalid_patterns: vec![],
5540 },
5541 );
5542 agent_settings::AgentSettings::override_global(settings, cx);
5543 });
5544
    let context_server_registry =
        cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let templates = crate::Templates::new();
    let thread = cx.new(|cx| {
        crate::Thread::new(
            project.clone(),
            cx.new(|_cx| prompt_store::ProjectContext::default()),
            context_server_registry,
            templates.clone(),
            None,
            cx,
        )
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::EditFileTool::new(
        project,
        thread.downgrade(),
        language_registry,
        templates,
    ));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

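    // README.md matches the \.md$ allow rule, so the edit should proceed without
    // requesting authorization despite the Confirm default.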
    let _task = cx.update(|cx| {
        tool.run(
            crate::EditFileToolInput {
                display_description: "Edit README".to_string(),
                path: "root/README.md".into(),
                mode: crate::EditFileMode::Edit,
            },
            event_stream,
            cx,
        )
    });

    cx.run_until_parked();

    let event = rx.try_next();
    assert!(
        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
        "expected no authorization request for allowed .md file"
    );
}

#[gpui::test]
async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
        settings.tool_permissions.tools.insert(
            "fetch".into(),
            agent_settings::ToolRules {
                default_mode: settings::ToolPermissionMode::Allow,
                always_allow: vec![],
                always_deny: vec![
                    agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
                ],
                always_confirm: vec![],
                invalid_patterns: vec![],
            },
        );
        agent_settings::AgentSettings::override_global(settings, cx);
    });

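    // The fake client would answer any request with a 200, so an error here can only
    // come from the deny rule matching the URL.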
    let http_client = gpui::http_client::FakeHttpClient::with_200_response();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::FetchTool::new(http_client));
    let (event_stream, _rx) = crate::ToolCallEventStream::test();

    let input: crate::FetchToolInput =
        serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();

    let task = cx.update(|cx| tool.run(input, event_stream, cx));

    let result = task.await;
    assert!(result.is_err(), "expected fetch to be blocked");
    assert!(
        result.unwrap_err().to_string().contains("blocked"),
        "error should mention the fetch was blocked"
    );
}

#[gpui::test]
async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
        settings.always_allow_tool_actions = false;
        settings.tool_permissions.tools.insert(
            "fetch".into(),
            agent_settings::ToolRules {
                default_mode: settings::ToolPermissionMode::Confirm,
                always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
                always_deny: vec![],
                always_confirm: vec![],
                invalid_patterns: vec![],
            },
        );
        agent_settings::AgentSettings::override_global(settings, cx);
    });

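    // docs.rs matches the always_allow pattern, so the Confirm default should not
    // produce an authorization request.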
    let http_client = gpui::http_client::FakeHttpClient::with_200_response();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::FetchTool::new(http_client));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let input: crate::FetchToolInput =
        serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();

    let _task = cx.update(|cx| tool.run(input, event_stream, cx));

    cx.run_until_parked();

    let event = rx.try_next();
    assert!(
        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
        "expected no authorization request for allowed docs.rs URL"
    );
}

#[gpui::test]
async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Add a tool so we can simulate tool calls
    thread.update(cx, |thread, _cx| {
        thread.add_tool(EchoTool);
    });

    // Start a turn by sending a message
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate the model making a tool call
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: r#"{"text": "hello"}"#.into(),
            input: json!({"text": "hello"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));

    // Queue a message before ending the stream
    thread.update(cx, |thread, _cx| {
        thread.queue_message(
            vec![acp::ContentBlock::Text(acp::TextContent::new(
                "This is my queued message".to_string(),
            ))],
            vec![],
        );
    });

    // Now end the stream; the tool will run, and the boundary check should see the queued message
    fake_model.end_last_completion_stream();

    // Collect all events until the turn stops
    let all_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we received the tool call event
    let tool_call_ids: Vec<_> = all_events
        .iter()
        .filter_map(|e| match e {
            Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
            _ => None,
        })
        .collect();
    assert_eq!(
        tool_call_ids,
        vec!["tool_1"],
        "Should have received a tool call event for our echo tool"
    );

    // The turn should have stopped with EndTurn
    let stop_reasons = stop_events(all_events);
    assert_eq!(
        stop_reasons,
        vec![acp::StopReason::EndTurn],
        "Turn should have ended after tool completion due to queued message"
    );

    // Verify the queued message is still there
    thread.update(cx, |thread, _cx| {
        let queued = thread.queued_messages();
        assert_eq!(queued.len(), 1, "Should still have one queued message");
        assert!(matches!(
            &queued[0].content[0],
            acp::ContentBlock::Text(t) if t.text == "This is my queued message"
        ));
    });

    // The thread should be idle now that the turn has ended
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Turn should be complete after ending at the queued-message boundary"
        );
    });
}