1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod test_tools;
56use test_tools::*;
57
58fn init_test(cx: &mut TestAppContext) {
59 cx.update(|cx| {
60 let settings_store = SettingsStore::test(cx);
61 cx.set_global(settings_store);
62 });
63}
64
65struct FakeTerminalHandle {
66 killed: Arc<AtomicBool>,
67 stopped_by_user: Arc<AtomicBool>,
68 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
69 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
70 output: acp::TerminalOutputResponse,
71 id: acp::TerminalId,
72}
73
74impl FakeTerminalHandle {
75 fn new_never_exits(cx: &mut App) -> Self {
76 let killed = Arc::new(AtomicBool::new(false));
77 let stopped_by_user = Arc::new(AtomicBool::new(false));
78
79 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
80
81 let wait_for_exit = cx
82 .spawn(async move |_cx| {
83 // Wait for the exit signal (sent when kill() is called)
84 let _ = exit_receiver.await;
85 acp::TerminalExitStatus::new()
86 })
87 .shared();
88
89 Self {
90 killed,
91 stopped_by_user,
92 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
93 wait_for_exit,
94 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
95 id: acp::TerminalId::new("fake_terminal".to_string()),
96 }
97 }
98
99 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
100 let killed = Arc::new(AtomicBool::new(false));
101 let stopped_by_user = Arc::new(AtomicBool::new(false));
102 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
103
104 let wait_for_exit = cx
105 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
106 .shared();
107
108 Self {
109 killed,
110 stopped_by_user,
111 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
112 wait_for_exit,
113 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
114 id: acp::TerminalId::new("fake_terminal".to_string()),
115 }
116 }
117
118 fn was_killed(&self) -> bool {
119 self.killed.load(Ordering::SeqCst)
120 }
121
122 fn set_stopped_by_user(&self, stopped: bool) {
123 self.stopped_by_user.store(stopped, Ordering::SeqCst);
124 }
125
126 fn signal_exit(&self) {
127 if let Some(sender) = self.exit_sender.borrow_mut().take() {
128 let _ = sender.send(());
129 }
130 }
131}
132
133impl crate::TerminalHandle for FakeTerminalHandle {
134 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
135 Ok(self.id.clone())
136 }
137
138 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
139 Ok(self.output.clone())
140 }
141
142 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
143 Ok(self.wait_for_exit.clone())
144 }
145
146 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
147 self.killed.store(true, Ordering::SeqCst);
148 self.signal_exit();
149 Ok(())
150 }
151
152 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
153 Ok(self.stopped_by_user.load(Ordering::SeqCst))
154 }
155}
156
157struct FakeThreadEnvironment {
158 handle: Rc<FakeTerminalHandle>,
159}
160
161impl crate::ThreadEnvironment for FakeThreadEnvironment {
162 fn create_terminal(
163 &self,
164 _command: String,
165 _cwd: Option<std::path::PathBuf>,
166 _output_byte_limit: Option<u64>,
167 _cx: &mut AsyncApp,
168 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
169 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
170 }
171}
172
173/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
174struct MultiTerminalEnvironment {
175 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
176}
177
178impl MultiTerminalEnvironment {
179 fn new() -> Self {
180 Self {
181 handles: std::cell::RefCell::new(Vec::new()),
182 }
183 }
184
185 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
186 self.handles.borrow().clone()
187 }
188}
189
190impl crate::ThreadEnvironment for MultiTerminalEnvironment {
191 fn create_terminal(
192 &self,
193 _command: String,
194 _cwd: Option<std::path::PathBuf>,
195 _output_byte_limit: Option<u64>,
196 cx: &mut AsyncApp,
197 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
198 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
199 self.handles.borrow_mut().push(handle.clone());
200 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
201 }
202}
203
204fn always_allow_tools(cx: &mut TestAppContext) {
205 cx.update(|cx| {
206 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
207 settings.always_allow_tool_actions = true;
208 agent_settings::AgentSettings::override_global(settings, cx);
209 });
210}
211
212#[gpui::test]
213async fn test_echo(cx: &mut TestAppContext) {
214 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
215 let fake_model = model.as_fake();
216
217 let events = thread
218 .update(cx, |thread, cx| {
219 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
220 })
221 .unwrap();
222 cx.run_until_parked();
223 fake_model.send_last_completion_stream_text_chunk("Hello");
224 fake_model
225 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
226 fake_model.end_last_completion_stream();
227
228 let events = events.collect().await;
229 thread.update(cx, |thread, _cx| {
230 assert_eq!(
231 thread.last_message().unwrap().to_markdown(),
232 indoc! {"
233 ## Assistant
234
235 Hello
236 "}
237 )
238 });
239 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
240}
241
242#[gpui::test]
243async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
244 init_test(cx);
245 always_allow_tools(cx);
246
247 let fs = FakeFs::new(cx.executor());
248 let project = Project::test(fs, [], cx).await;
249
250 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
251 let environment = Rc::new(FakeThreadEnvironment {
252 handle: handle.clone(),
253 });
254
255 #[allow(clippy::arc_with_non_send_sync)]
256 let tool = Arc::new(crate::TerminalTool::new(project, environment));
257 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
258
259 let task = cx.update(|cx| {
260 tool.run(
261 crate::TerminalToolInput {
262 command: "sleep 1000".to_string(),
263 cd: ".".to_string(),
264 timeout_ms: Some(5),
265 },
266 event_stream,
267 cx,
268 )
269 });
270
271 let update = rx.expect_update_fields().await;
272 assert!(
273 update.content.iter().any(|blocks| {
274 blocks
275 .iter()
276 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
277 }),
278 "expected tool call update to include terminal content"
279 );
280
281 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
282
283 let deadline = std::time::Instant::now() + Duration::from_millis(500);
284 loop {
285 if let Some(result) = task_future.as_mut().now_or_never() {
286 let result = result.expect("terminal tool task should complete");
287
288 assert!(
289 handle.was_killed(),
290 "expected terminal handle to be killed on timeout"
291 );
292 assert!(
293 result.contains("partial output"),
294 "expected result to include terminal output, got: {result}"
295 );
296 return;
297 }
298
299 if std::time::Instant::now() >= deadline {
300 panic!("timed out waiting for terminal tool task to complete");
301 }
302
303 cx.run_until_parked();
304 cx.background_executor.timer(Duration::from_millis(1)).await;
305 }
306}
307
308#[gpui::test]
309#[ignore]
310async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
311 init_test(cx);
312 always_allow_tools(cx);
313
314 let fs = FakeFs::new(cx.executor());
315 let project = Project::test(fs, [], cx).await;
316
317 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
318 let environment = Rc::new(FakeThreadEnvironment {
319 handle: handle.clone(),
320 });
321
322 #[allow(clippy::arc_with_non_send_sync)]
323 let tool = Arc::new(crate::TerminalTool::new(project, environment));
324 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
325
326 let _task = cx.update(|cx| {
327 tool.run(
328 crate::TerminalToolInput {
329 command: "sleep 1000".to_string(),
330 cd: ".".to_string(),
331 timeout_ms: None,
332 },
333 event_stream,
334 cx,
335 )
336 });
337
338 let update = rx.expect_update_fields().await;
339 assert!(
340 update.content.iter().any(|blocks| {
341 blocks
342 .iter()
343 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
344 }),
345 "expected tool call update to include terminal content"
346 );
347
348 cx.background_executor
349 .timer(Duration::from_millis(25))
350 .await;
351
352 assert!(
353 !handle.was_killed(),
354 "did not expect terminal handle to be killed without a timeout"
355 );
356}
357
358#[gpui::test]
359async fn test_thinking(cx: &mut TestAppContext) {
360 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
361 let fake_model = model.as_fake();
362
363 let events = thread
364 .update(cx, |thread, cx| {
365 thread.send(
366 UserMessageId::new(),
367 [indoc! {"
368 Testing:
369
370 Generate a thinking step where you just think the word 'Think',
371 and have your final answer be 'Hello'
372 "}],
373 cx,
374 )
375 })
376 .unwrap();
377 cx.run_until_parked();
378 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
379 text: "Think".to_string(),
380 signature: None,
381 });
382 fake_model.send_last_completion_stream_text_chunk("Hello");
383 fake_model
384 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
385 fake_model.end_last_completion_stream();
386
387 let events = events.collect().await;
388 thread.update(cx, |thread, _cx| {
389 assert_eq!(
390 thread.last_message().unwrap().to_markdown(),
391 indoc! {"
392 ## Assistant
393
394 <think>Think</think>
395 Hello
396 "}
397 )
398 });
399 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
400}
401
402#[gpui::test]
403async fn test_system_prompt(cx: &mut TestAppContext) {
404 let ThreadTest {
405 model,
406 thread,
407 project_context,
408 ..
409 } = setup(cx, TestModel::Fake).await;
410 let fake_model = model.as_fake();
411
412 project_context.update(cx, |project_context, _cx| {
413 project_context.shell = "test-shell".into()
414 });
415 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
416 thread
417 .update(cx, |thread, cx| {
418 thread.send(UserMessageId::new(), ["abc"], cx)
419 })
420 .unwrap();
421 cx.run_until_parked();
422 let mut pending_completions = fake_model.pending_completions();
423 assert_eq!(
424 pending_completions.len(),
425 1,
426 "unexpected pending completions: {:?}",
427 pending_completions
428 );
429
430 let pending_completion = pending_completions.pop().unwrap();
431 assert_eq!(pending_completion.messages[0].role, Role::System);
432
433 let system_message = &pending_completion.messages[0];
434 let system_prompt = system_message.content[0].to_str().unwrap();
435 assert!(
436 system_prompt.contains("test-shell"),
437 "unexpected system message: {:?}",
438 system_message
439 );
440 assert!(
441 system_prompt.contains("## Fixing Diagnostics"),
442 "unexpected system message: {:?}",
443 system_message
444 );
445}
446
447#[gpui::test]
448async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
449 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
450 let fake_model = model.as_fake();
451
452 thread
453 .update(cx, |thread, cx| {
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 let mut pending_completions = fake_model.pending_completions();
459 assert_eq!(
460 pending_completions.len(),
461 1,
462 "unexpected pending completions: {:?}",
463 pending_completions
464 );
465
466 let pending_completion = pending_completions.pop().unwrap();
467 assert_eq!(pending_completion.messages[0].role, Role::System);
468
469 let system_message = &pending_completion.messages[0];
470 let system_prompt = system_message.content[0].to_str().unwrap();
471 assert!(
472 !system_prompt.contains("## Tool Use"),
473 "unexpected system message: {:?}",
474 system_message
475 );
476 assert!(
477 !system_prompt.contains("## Fixing Diagnostics"),
478 "unexpected system message: {:?}",
479 system_message
480 );
481}
482
483#[gpui::test]
484async fn test_prompt_caching(cx: &mut TestAppContext) {
485 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
486 let fake_model = model.as_fake();
487
488 // Send initial user message and verify it's cached
489 thread
490 .update(cx, |thread, cx| {
491 thread.send(UserMessageId::new(), ["Message 1"], cx)
492 })
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 assert_eq!(
498 completion.messages[1..],
499 vec![LanguageModelRequestMessage {
500 role: Role::User,
501 content: vec!["Message 1".into()],
502 cache: true,
503 reasoning_details: None,
504 }]
505 );
506 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
507 "Response to Message 1".into(),
508 ));
509 fake_model.end_last_completion_stream();
510 cx.run_until_parked();
511
512 // Send another user message and verify only the latest is cached
513 thread
514 .update(cx, |thread, cx| {
515 thread.send(UserMessageId::new(), ["Message 2"], cx)
516 })
517 .unwrap();
518 cx.run_until_parked();
519
520 let completion = fake_model.pending_completions().pop().unwrap();
521 assert_eq!(
522 completion.messages[1..],
523 vec![
524 LanguageModelRequestMessage {
525 role: Role::User,
526 content: vec!["Message 1".into()],
527 cache: false,
528 reasoning_details: None,
529 },
530 LanguageModelRequestMessage {
531 role: Role::Assistant,
532 content: vec!["Response to Message 1".into()],
533 cache: false,
534 reasoning_details: None,
535 },
536 LanguageModelRequestMessage {
537 role: Role::User,
538 content: vec!["Message 2".into()],
539 cache: true,
540 reasoning_details: None,
541 }
542 ]
543 );
544 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
545 "Response to Message 2".into(),
546 ));
547 fake_model.end_last_completion_stream();
548 cx.run_until_parked();
549
550 // Simulate a tool call and verify that the latest tool result is cached
551 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
552 thread
553 .update(cx, |thread, cx| {
554 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
555 })
556 .unwrap();
557 cx.run_until_parked();
558
559 let tool_use = LanguageModelToolUse {
560 id: "tool_1".into(),
561 name: EchoTool::name().into(),
562 raw_input: json!({"text": "test"}).to_string(),
563 input: json!({"text": "test"}),
564 is_input_complete: true,
565 thought_signature: None,
566 };
567 fake_model
568 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
569 fake_model.end_last_completion_stream();
570 cx.run_until_parked();
571
572 let completion = fake_model.pending_completions().pop().unwrap();
573 let tool_result = LanguageModelToolResult {
574 tool_use_id: "tool_1".into(),
575 tool_name: EchoTool::name().into(),
576 is_error: false,
577 content: "test".into(),
578 output: Some("test".into()),
579 };
580 assert_eq!(
581 completion.messages[1..],
582 vec![
583 LanguageModelRequestMessage {
584 role: Role::User,
585 content: vec!["Message 1".into()],
586 cache: false,
587 reasoning_details: None,
588 },
589 LanguageModelRequestMessage {
590 role: Role::Assistant,
591 content: vec!["Response to Message 1".into()],
592 cache: false,
593 reasoning_details: None,
594 },
595 LanguageModelRequestMessage {
596 role: Role::User,
597 content: vec!["Message 2".into()],
598 cache: false,
599 reasoning_details: None,
600 },
601 LanguageModelRequestMessage {
602 role: Role::Assistant,
603 content: vec!["Response to Message 2".into()],
604 cache: false,
605 reasoning_details: None,
606 },
607 LanguageModelRequestMessage {
608 role: Role::User,
609 content: vec!["Use the echo tool".into()],
610 cache: false,
611 reasoning_details: None,
612 },
613 LanguageModelRequestMessage {
614 role: Role::Assistant,
615 content: vec![MessageContent::ToolUse(tool_use)],
616 cache: false,
617 reasoning_details: None,
618 },
619 LanguageModelRequestMessage {
620 role: Role::User,
621 content: vec![MessageContent::ToolResult(tool_result)],
622 cache: true,
623 reasoning_details: None,
624 }
625 ]
626 );
627}
628
629#[gpui::test]
630#[cfg_attr(not(feature = "e2e"), ignore)]
631async fn test_basic_tool_calls(cx: &mut TestAppContext) {
632 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
633
634 // Test a tool call that's likely to complete *before* streaming stops.
635 let events = thread
636 .update(cx, |thread, cx| {
637 thread.add_tool(EchoTool);
638 thread.send(
639 UserMessageId::new(),
640 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
641 cx,
642 )
643 })
644 .unwrap()
645 .collect()
646 .await;
647 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
648
649 // Test a tool calls that's likely to complete *after* streaming stops.
650 let events = thread
651 .update(cx, |thread, cx| {
652 thread.remove_tool(&EchoTool::name());
653 thread.add_tool(DelayTool);
654 thread.send(
655 UserMessageId::new(),
656 [
657 "Now call the delay tool with 200ms.",
658 "When the timer goes off, then you echo the output of the tool.",
659 ],
660 cx,
661 )
662 })
663 .unwrap()
664 .collect()
665 .await;
666 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
667 thread.update(cx, |thread, _cx| {
668 assert!(
669 thread
670 .last_message()
671 .unwrap()
672 .as_agent_message()
673 .unwrap()
674 .content
675 .iter()
676 .any(|content| {
677 if let AgentMessageContent::Text(text) = content {
678 text.contains("Ding")
679 } else {
680 false
681 }
682 }),
683 "{}",
684 thread.to_markdown()
685 );
686 });
687}
688
689#[gpui::test]
690#[cfg_attr(not(feature = "e2e"), ignore)]
691async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
692 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
693
694 // Test a tool call that's likely to complete *before* streaming stops.
695 let mut events = thread
696 .update(cx, |thread, cx| {
697 thread.add_tool(WordListTool);
698 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
699 })
700 .unwrap();
701
702 let mut saw_partial_tool_use = false;
703 while let Some(event) = events.next().await {
704 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
705 thread.update(cx, |thread, _cx| {
706 // Look for a tool use in the thread's last message
707 let message = thread.last_message().unwrap();
708 let agent_message = message.as_agent_message().unwrap();
709 let last_content = agent_message.content.last().unwrap();
710 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
711 assert_eq!(last_tool_use.name.as_ref(), "word_list");
712 if tool_call.status == acp::ToolCallStatus::Pending {
713 if !last_tool_use.is_input_complete
714 && last_tool_use.input.get("g").is_none()
715 {
716 saw_partial_tool_use = true;
717 }
718 } else {
719 last_tool_use
720 .input
721 .get("a")
722 .expect("'a' has streamed because input is now complete");
723 last_tool_use
724 .input
725 .get("g")
726 .expect("'g' has streamed because input is now complete");
727 }
728 } else {
729 panic!("last content should be a tool use");
730 }
731 });
732 }
733 }
734
735 assert!(
736 saw_partial_tool_use,
737 "should see at least one partially streamed tool use in the history"
738 );
739}
740
741#[gpui::test]
742async fn test_tool_authorization(cx: &mut TestAppContext) {
743 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
744 let fake_model = model.as_fake();
745
746 let mut events = thread
747 .update(cx, |thread, cx| {
748 thread.add_tool(ToolRequiringPermission);
749 thread.send(UserMessageId::new(), ["abc"], cx)
750 })
751 .unwrap();
752 cx.run_until_parked();
753 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
754 LanguageModelToolUse {
755 id: "tool_id_1".into(),
756 name: ToolRequiringPermission::name().into(),
757 raw_input: "{}".into(),
758 input: json!({}),
759 is_input_complete: true,
760 thought_signature: None,
761 },
762 ));
763 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
764 LanguageModelToolUse {
765 id: "tool_id_2".into(),
766 name: ToolRequiringPermission::name().into(),
767 raw_input: "{}".into(),
768 input: json!({}),
769 is_input_complete: true,
770 thought_signature: None,
771 },
772 ));
773 fake_model.end_last_completion_stream();
774 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
775 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
776
777 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
778 tool_call_auth_1
779 .response
780 .send(acp::PermissionOptionId::new("allow"))
781 .unwrap();
782 cx.run_until_parked();
783
784 // Reject the second - send "deny" option_id directly since Deny is now a button
785 tool_call_auth_2
786 .response
787 .send(acp::PermissionOptionId::new("deny"))
788 .unwrap();
789 cx.run_until_parked();
790
791 let completion = fake_model.pending_completions().pop().unwrap();
792 let message = completion.messages.last().unwrap();
793 assert_eq!(
794 message.content,
795 vec![
796 language_model::MessageContent::ToolResult(LanguageModelToolResult {
797 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
798 tool_name: ToolRequiringPermission::name().into(),
799 is_error: false,
800 content: "Allowed".into(),
801 output: Some("Allowed".into())
802 }),
803 language_model::MessageContent::ToolResult(LanguageModelToolResult {
804 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
805 tool_name: ToolRequiringPermission::name().into(),
806 is_error: true,
807 content: "Permission to run tool denied by user".into(),
808 output: Some("Permission to run tool denied by user".into())
809 })
810 ]
811 );
812
813 // Simulate yet another tool call.
814 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
815 LanguageModelToolUse {
816 id: "tool_id_3".into(),
817 name: ToolRequiringPermission::name().into(),
818 raw_input: "{}".into(),
819 input: json!({}),
820 is_input_complete: true,
821 thought_signature: None,
822 },
823 ));
824 fake_model.end_last_completion_stream();
825
826 // Respond by always allowing tools - send transformed option_id
827 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
828 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
829 tool_call_auth_3
830 .response
831 .send(acp::PermissionOptionId::new(
832 "always_allow:tool_requiring_permission",
833 ))
834 .unwrap();
835 cx.run_until_parked();
836 let completion = fake_model.pending_completions().pop().unwrap();
837 let message = completion.messages.last().unwrap();
838 assert_eq!(
839 message.content,
840 vec![language_model::MessageContent::ToolResult(
841 LanguageModelToolResult {
842 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
843 tool_name: ToolRequiringPermission::name().into(),
844 is_error: false,
845 content: "Allowed".into(),
846 output: Some("Allowed".into())
847 }
848 )]
849 );
850
851 // Simulate a final tool call, ensuring we don't trigger authorization.
852 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
853 LanguageModelToolUse {
854 id: "tool_id_4".into(),
855 name: ToolRequiringPermission::name().into(),
856 raw_input: "{}".into(),
857 input: json!({}),
858 is_input_complete: true,
859 thought_signature: None,
860 },
861 ));
862 fake_model.end_last_completion_stream();
863 cx.run_until_parked();
864 let completion = fake_model.pending_completions().pop().unwrap();
865 let message = completion.messages.last().unwrap();
866 assert_eq!(
867 message.content,
868 vec![language_model::MessageContent::ToolResult(
869 LanguageModelToolResult {
870 tool_use_id: "tool_id_4".into(),
871 tool_name: ToolRequiringPermission::name().into(),
872 is_error: false,
873 content: "Allowed".into(),
874 output: Some("Allowed".into())
875 }
876 )]
877 );
878}
879
880#[gpui::test]
881async fn test_tool_hallucination(cx: &mut TestAppContext) {
882 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
883 let fake_model = model.as_fake();
884
885 let mut events = thread
886 .update(cx, |thread, cx| {
887 thread.send(UserMessageId::new(), ["abc"], cx)
888 })
889 .unwrap();
890 cx.run_until_parked();
891 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
892 LanguageModelToolUse {
893 id: "tool_id_1".into(),
894 name: "nonexistent_tool".into(),
895 raw_input: "{}".into(),
896 input: json!({}),
897 is_input_complete: true,
898 thought_signature: None,
899 },
900 ));
901 fake_model.end_last_completion_stream();
902
903 let tool_call = expect_tool_call(&mut events).await;
904 assert_eq!(tool_call.title, "nonexistent_tool");
905 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
906 let update = expect_tool_call_update_fields(&mut events).await;
907 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
908}
909
910async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
911 let event = events
912 .next()
913 .await
914 .expect("no tool call authorization event received")
915 .unwrap();
916 match event {
917 ThreadEvent::ToolCall(tool_call) => tool_call,
918 event => {
919 panic!("Unexpected event {event:?}");
920 }
921 }
922}
923
924async fn expect_tool_call_update_fields(
925 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
926) -> acp::ToolCallUpdate {
927 let event = events
928 .next()
929 .await
930 .expect("no tool call authorization event received")
931 .unwrap();
932 match event {
933 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
934 event => {
935 panic!("Unexpected event {event:?}");
936 }
937 }
938}
939
940async fn next_tool_call_authorization(
941 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
942) -> ToolCallAuthorization {
943 loop {
944 let event = events
945 .next()
946 .await
947 .expect("no tool call authorization event received")
948 .unwrap();
949 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
950 let permission_kinds = tool_call_authorization
951 .options
952 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
953 .map(|option| option.kind);
954 let allow_once = tool_call_authorization
955 .options
956 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
957 .map(|option| option.kind);
958
959 assert_eq!(
960 permission_kinds,
961 Some(acp::PermissionOptionKind::AllowAlways)
962 );
963 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
964 return tool_call_authorization;
965 }
966 }
967}
968
// A terminal command with a recognizable program yields three choices: tool-wide,
// per-command pattern, and one-time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);

    let mut allow_labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        allow_labels.push(choice.allow.name.as_ref());
    }
    assert!(allow_labels.contains(&"Always for terminal"));
    assert!(allow_labels.contains(&"Always for `cargo` commands"));
    assert!(allow_labels.contains(&"Only this time"));
}
987
// An edit_file request offers a tool-wide choice plus one scoped to the file's directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new("edit_file", "src/main.rs");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut allow_labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        allow_labels.push(choice.allow.name.as_ref());
    }
    assert!(allow_labels.contains(&"Always for edit file"));
    assert!(allow_labels.contains(&"Always for `src/`"));
}
1004
// A fetch request offers a tool-wide choice plus one scoped to the URL's domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context = ToolPermissionContext::new("fetch", "https://docs.rs/gpui");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut allow_labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        allow_labels.push(choice.allow.name.as_ref());
    }
    assert!(allow_labels.contains(&"Always for fetch"));
    assert!(allow_labels.contains(&"Always for `docs.rs`"));
}
1021
// For this script-path command only the tool-wide and one-time choices are expected,
// with no per-command pattern choice.
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new("terminal", "./deploy.sh --production");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 2);

    let mut allow_labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        allow_labels.push(choice.allow.name.as_ref());
    }
    assert!(allow_labels.contains(&"Always for terminal"));
    assert!(allow_labels.contains(&"Only this time"));
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1040
// Checks the option ids attached to the allow and deny side of each terminal choice:
// tool-wide, one-time, and pattern-scoped.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Collect both id lists in a single pass over the choices.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    let has_allow_pattern = allow_ids
        .iter()
        .any(|id| id.starts_with("always_allow_pattern:terminal:"));
    assert!(has_allow_pattern, "Missing allow pattern option");

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    let has_deny_pattern = deny_ids
        .iter()
        .any(|id| id.starts_with("always_deny_pattern:terminal:"));
    assert!(has_deny_pattern, "Missing deny pattern option");
}
1077
1078#[gpui::test]
1079#[cfg_attr(not(feature = "e2e"), ignore)]
1080async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1081 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1082
1083 // Test concurrent tool calls with different delay times
1084 let events = thread
1085 .update(cx, |thread, cx| {
1086 thread.add_tool(DelayTool);
1087 thread.send(
1088 UserMessageId::new(),
1089 [
1090 "Call the delay tool twice in the same message.",
1091 "Once with 100ms. Once with 300ms.",
1092 "When both timers are complete, describe the outputs.",
1093 ],
1094 cx,
1095 )
1096 })
1097 .unwrap()
1098 .collect()
1099 .await;
1100
1101 let stop_reasons = stop_events(events);
1102 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1103
1104 thread.update(cx, |thread, _cx| {
1105 let last_message = thread.last_message().unwrap();
1106 let agent_message = last_message.as_agent_message().unwrap();
1107 let text = agent_message
1108 .content
1109 .iter()
1110 .filter_map(|content| {
1111 if let AgentMessageContent::Text(text) = content {
1112 Some(text.as_str())
1113 } else {
1114 None
1115 }
1116 })
1117 .collect::<String>();
1118
1119 assert!(text.contains("Ding"));
1120 });
1121}
1122
1123#[gpui::test]
1124async fn test_profiles(cx: &mut TestAppContext) {
1125 let ThreadTest {
1126 model, thread, fs, ..
1127 } = setup(cx, TestModel::Fake).await;
1128 let fake_model = model.as_fake();
1129
1130 thread.update(cx, |thread, _cx| {
1131 thread.add_tool(DelayTool);
1132 thread.add_tool(EchoTool);
1133 thread.add_tool(InfiniteTool);
1134 });
1135
1136 // Override profiles and wait for settings to be loaded.
1137 fs.insert_file(
1138 paths::settings_file(),
1139 json!({
1140 "agent": {
1141 "profiles": {
1142 "test-1": {
1143 "name": "Test Profile 1",
1144 "tools": {
1145 EchoTool::name(): true,
1146 DelayTool::name(): true,
1147 }
1148 },
1149 "test-2": {
1150 "name": "Test Profile 2",
1151 "tools": {
1152 InfiniteTool::name(): true,
1153 }
1154 }
1155 }
1156 }
1157 })
1158 .to_string()
1159 .into_bytes(),
1160 )
1161 .await;
1162 cx.run_until_parked();
1163
1164 // Test that test-1 profile (default) has echo and delay tools
1165 thread
1166 .update(cx, |thread, cx| {
1167 thread.set_profile(AgentProfileId("test-1".into()), cx);
1168 thread.send(UserMessageId::new(), ["test"], cx)
1169 })
1170 .unwrap();
1171 cx.run_until_parked();
1172
1173 let mut pending_completions = fake_model.pending_completions();
1174 assert_eq!(pending_completions.len(), 1);
1175 let completion = pending_completions.pop().unwrap();
1176 let tool_names: Vec<String> = completion
1177 .tools
1178 .iter()
1179 .map(|tool| tool.name.clone())
1180 .collect();
1181 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1182 fake_model.end_last_completion_stream();
1183
1184 // Switch to test-2 profile, and verify that it has only the infinite tool.
1185 thread
1186 .update(cx, |thread, cx| {
1187 thread.set_profile(AgentProfileId("test-2".into()), cx);
1188 thread.send(UserMessageId::new(), ["test2"], cx)
1189 })
1190 .unwrap();
1191 cx.run_until_parked();
1192 let mut pending_completions = fake_model.pending_completions();
1193 assert_eq!(pending_completions.len(), 1);
1194 let completion = pending_completions.pop().unwrap();
1195 let tool_names: Vec<String> = completion
1196 .tools
1197 .iter()
1198 .map(|tool| tool.name.clone())
1199 .collect();
1200 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1201}
1202
1203#[gpui::test]
1204async fn test_mcp_tools(cx: &mut TestAppContext) {
1205 let ThreadTest {
1206 model,
1207 thread,
1208 context_server_store,
1209 fs,
1210 ..
1211 } = setup(cx, TestModel::Fake).await;
1212 let fake_model = model.as_fake();
1213
1214 // Override profiles and wait for settings to be loaded.
1215 fs.insert_file(
1216 paths::settings_file(),
1217 json!({
1218 "agent": {
1219 "always_allow_tool_actions": true,
1220 "profiles": {
1221 "test": {
1222 "name": "Test Profile",
1223 "enable_all_context_servers": true,
1224 "tools": {
1225 EchoTool::name(): true,
1226 }
1227 },
1228 }
1229 }
1230 })
1231 .to_string()
1232 .into_bytes(),
1233 )
1234 .await;
1235 cx.run_until_parked();
1236 thread.update(cx, |thread, cx| {
1237 thread.set_profile(AgentProfileId("test".into()), cx)
1238 });
1239
1240 let mut mcp_tool_calls = setup_context_server(
1241 "test_server",
1242 vec![context_server::types::Tool {
1243 name: "echo".into(),
1244 description: None,
1245 input_schema: serde_json::to_value(EchoTool::input_schema(
1246 LanguageModelToolSchemaFormat::JsonSchema,
1247 ))
1248 .unwrap(),
1249 output_schema: None,
1250 annotations: None,
1251 }],
1252 &context_server_store,
1253 cx,
1254 );
1255
1256 let events = thread.update(cx, |thread, cx| {
1257 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1258 });
1259 cx.run_until_parked();
1260
1261 // Simulate the model calling the MCP tool.
1262 let completion = fake_model.pending_completions().pop().unwrap();
1263 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1264 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1265 LanguageModelToolUse {
1266 id: "tool_1".into(),
1267 name: "echo".into(),
1268 raw_input: json!({"text": "test"}).to_string(),
1269 input: json!({"text": "test"}),
1270 is_input_complete: true,
1271 thought_signature: None,
1272 },
1273 ));
1274 fake_model.end_last_completion_stream();
1275 cx.run_until_parked();
1276
1277 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1278 assert_eq!(tool_call_params.name, "echo");
1279 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1280 tool_call_response
1281 .send(context_server::types::CallToolResponse {
1282 content: vec![context_server::types::ToolResponseContent::Text {
1283 text: "test".into(),
1284 }],
1285 is_error: None,
1286 meta: None,
1287 structured_content: None,
1288 })
1289 .unwrap();
1290 cx.run_until_parked();
1291
1292 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1293 fake_model.send_last_completion_stream_text_chunk("Done!");
1294 fake_model.end_last_completion_stream();
1295 events.collect::<Vec<_>>().await;
1296
1297 // Send again after adding the echo tool, ensuring the name collision is resolved.
1298 let events = thread.update(cx, |thread, cx| {
1299 thread.add_tool(EchoTool);
1300 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1301 });
1302 cx.run_until_parked();
1303 let completion = fake_model.pending_completions().pop().unwrap();
1304 assert_eq!(
1305 tool_names_for_completion(&completion),
1306 vec!["echo", "test_server_echo"]
1307 );
1308 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1309 LanguageModelToolUse {
1310 id: "tool_2".into(),
1311 name: "test_server_echo".into(),
1312 raw_input: json!({"text": "mcp"}).to_string(),
1313 input: json!({"text": "mcp"}),
1314 is_input_complete: true,
1315 thought_signature: None,
1316 },
1317 ));
1318 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1319 LanguageModelToolUse {
1320 id: "tool_3".into(),
1321 name: "echo".into(),
1322 raw_input: json!({"text": "native"}).to_string(),
1323 input: json!({"text": "native"}),
1324 is_input_complete: true,
1325 thought_signature: None,
1326 },
1327 ));
1328 fake_model.end_last_completion_stream();
1329 cx.run_until_parked();
1330
1331 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1332 assert_eq!(tool_call_params.name, "echo");
1333 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1334 tool_call_response
1335 .send(context_server::types::CallToolResponse {
1336 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1337 is_error: None,
1338 meta: None,
1339 structured_content: None,
1340 })
1341 .unwrap();
1342 cx.run_until_parked();
1343
1344 // Ensure the tool results were inserted with the correct names.
1345 let completion = fake_model.pending_completions().pop().unwrap();
1346 assert_eq!(
1347 completion.messages.last().unwrap().content,
1348 vec![
1349 MessageContent::ToolResult(LanguageModelToolResult {
1350 tool_use_id: "tool_3".into(),
1351 tool_name: "echo".into(),
1352 is_error: false,
1353 content: "native".into(),
1354 output: Some("native".into()),
1355 },),
1356 MessageContent::ToolResult(LanguageModelToolResult {
1357 tool_use_id: "tool_2".into(),
1358 tool_name: "test_server_echo".into(),
1359 is_error: false,
1360 content: "mcp".into(),
1361 output: Some("mcp".into()),
1362 },),
1363 ]
1364 );
1365 fake_model.end_last_completion_stream();
1366 events.collect::<Vec<_>>().await;
1367}
1368
1369#[gpui::test]
1370async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1371 let ThreadTest {
1372 model,
1373 thread,
1374 context_server_store,
1375 fs,
1376 ..
1377 } = setup(cx, TestModel::Fake).await;
1378 let fake_model = model.as_fake();
1379
1380 // Set up a profile with all tools enabled
1381 fs.insert_file(
1382 paths::settings_file(),
1383 json!({
1384 "agent": {
1385 "profiles": {
1386 "test": {
1387 "name": "Test Profile",
1388 "enable_all_context_servers": true,
1389 "tools": {
1390 EchoTool::name(): true,
1391 DelayTool::name(): true,
1392 WordListTool::name(): true,
1393 ToolRequiringPermission::name(): true,
1394 InfiniteTool::name(): true,
1395 }
1396 },
1397 }
1398 }
1399 })
1400 .to_string()
1401 .into_bytes(),
1402 )
1403 .await;
1404 cx.run_until_parked();
1405
1406 thread.update(cx, |thread, cx| {
1407 thread.set_profile(AgentProfileId("test".into()), cx);
1408 thread.add_tool(EchoTool);
1409 thread.add_tool(DelayTool);
1410 thread.add_tool(WordListTool);
1411 thread.add_tool(ToolRequiringPermission);
1412 thread.add_tool(InfiniteTool);
1413 });
1414
1415 // Set up multiple context servers with some overlapping tool names
1416 let _server1_calls = setup_context_server(
1417 "xxx",
1418 vec![
1419 context_server::types::Tool {
1420 name: "echo".into(), // Conflicts with native EchoTool
1421 description: None,
1422 input_schema: serde_json::to_value(EchoTool::input_schema(
1423 LanguageModelToolSchemaFormat::JsonSchema,
1424 ))
1425 .unwrap(),
1426 output_schema: None,
1427 annotations: None,
1428 },
1429 context_server::types::Tool {
1430 name: "unique_tool_1".into(),
1431 description: None,
1432 input_schema: json!({"type": "object", "properties": {}}),
1433 output_schema: None,
1434 annotations: None,
1435 },
1436 ],
1437 &context_server_store,
1438 cx,
1439 );
1440
1441 let _server2_calls = setup_context_server(
1442 "yyy",
1443 vec![
1444 context_server::types::Tool {
1445 name: "echo".into(), // Also conflicts with native EchoTool
1446 description: None,
1447 input_schema: serde_json::to_value(EchoTool::input_schema(
1448 LanguageModelToolSchemaFormat::JsonSchema,
1449 ))
1450 .unwrap(),
1451 output_schema: None,
1452 annotations: None,
1453 },
1454 context_server::types::Tool {
1455 name: "unique_tool_2".into(),
1456 description: None,
1457 input_schema: json!({"type": "object", "properties": {}}),
1458 output_schema: None,
1459 annotations: None,
1460 },
1461 context_server::types::Tool {
1462 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1463 description: None,
1464 input_schema: json!({"type": "object", "properties": {}}),
1465 output_schema: None,
1466 annotations: None,
1467 },
1468 context_server::types::Tool {
1469 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1470 description: None,
1471 input_schema: json!({"type": "object", "properties": {}}),
1472 output_schema: None,
1473 annotations: None,
1474 },
1475 ],
1476 &context_server_store,
1477 cx,
1478 );
1479 let _server3_calls = setup_context_server(
1480 "zzz",
1481 vec![
1482 context_server::types::Tool {
1483 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1484 description: None,
1485 input_schema: json!({"type": "object", "properties": {}}),
1486 output_schema: None,
1487 annotations: None,
1488 },
1489 context_server::types::Tool {
1490 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1491 description: None,
1492 input_schema: json!({"type": "object", "properties": {}}),
1493 output_schema: None,
1494 annotations: None,
1495 },
1496 context_server::types::Tool {
1497 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1498 description: None,
1499 input_schema: json!({"type": "object", "properties": {}}),
1500 output_schema: None,
1501 annotations: None,
1502 },
1503 ],
1504 &context_server_store,
1505 cx,
1506 );
1507
1508 thread
1509 .update(cx, |thread, cx| {
1510 thread.send(UserMessageId::new(), ["Go"], cx)
1511 })
1512 .unwrap();
1513 cx.run_until_parked();
1514 let completion = fake_model.pending_completions().pop().unwrap();
1515 assert_eq!(
1516 tool_names_for_completion(&completion),
1517 vec![
1518 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1519 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1520 "delay",
1521 "echo",
1522 "infinite",
1523 "tool_requiring_permission",
1524 "unique_tool_1",
1525 "unique_tool_2",
1526 "word_list",
1527 "xxx_echo",
1528 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1529 "yyy_echo",
1530 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1531 ]
1532 );
1533}
1534
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // Ignored unless the `e2e` feature is enabled. Starts a turn in which the
    // model is asked to call the echo tool and then the infinite tool, cancels
    // the thread once both calls have been reported and echo has completed,
    // and checks that the event stream ends with a `Cancelled` stop and that a
    // later message still works.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // `expected_tools` is consumed front-to-back, so the tool calls must arrive
    // in exactly this order.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    // Remember the echo call's id so its completion update can be recognized.
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // A field update marking the echo tool call as completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1620
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // Cancels the thread while a terminal tool call is in flight and checks
    // that the fake terminal handle was killed, the turn stops with
    // `Cancelled`, and the stored tool result includes the terminal's output
    // together with the "user stopped" notice.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake terminal built with `new_never_exits`; the environment hands this
    // handle to the terminal tool.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    // The cancel task is detached rather than awaited; the loop below drives it.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // First tool use recorded in the last agent message.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1717
1718#[gpui::test]
1719async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1720 // This test verifies that tools which properly handle cancellation via
1721 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1722 // to cancellation and report that they were cancelled.
1723 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1724 always_allow_tools(cx);
1725 let fake_model = model.as_fake();
1726
1727 let (tool, was_cancelled) = CancellationAwareTool::new();
1728
1729 let mut events = thread
1730 .update(cx, |thread, cx| {
1731 thread.add_tool(tool);
1732 thread.send(
1733 UserMessageId::new(),
1734 ["call the cancellation aware tool"],
1735 cx,
1736 )
1737 })
1738 .unwrap();
1739
1740 cx.run_until_parked();
1741
1742 // Simulate the model calling the cancellation-aware tool
1743 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1744 LanguageModelToolUse {
1745 id: "cancellation_aware_1".into(),
1746 name: "cancellation_aware".into(),
1747 raw_input: r#"{}"#.into(),
1748 input: json!({}),
1749 is_input_complete: true,
1750 thought_signature: None,
1751 },
1752 ));
1753 fake_model.end_last_completion_stream();
1754
1755 cx.run_until_parked();
1756
1757 // Wait for the tool call to be reported
1758 let mut tool_started = false;
1759 let deadline = cx.executor().num_cpus() * 100;
1760 for _ in 0..deadline {
1761 cx.run_until_parked();
1762
1763 while let Some(Some(event)) = events.next().now_or_never() {
1764 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1765 if tool_call.title == "Cancellation Aware Tool" {
1766 tool_started = true;
1767 break;
1768 }
1769 }
1770 }
1771
1772 if tool_started {
1773 break;
1774 }
1775
1776 cx.background_executor
1777 .timer(Duration::from_millis(10))
1778 .await;
1779 }
1780 assert!(tool_started, "expected cancellation aware tool to start");
1781
1782 // Cancel the thread and wait for it to complete
1783 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1784
1785 // The cancel task should complete promptly because the tool handles cancellation
1786 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1787 futures::select! {
1788 _ = cancel_task.fuse() => {}
1789 _ = timeout.fuse() => {
1790 panic!("cancel task timed out - tool did not respond to cancellation");
1791 }
1792 }
1793
1794 // Verify the tool detected cancellation via its flag
1795 assert!(
1796 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1797 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1798 );
1799
1800 // Collect remaining events
1801 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1802
1803 // Verify we got a cancellation stop event
1804 assert_eq!(
1805 stop_events(remaining_events),
1806 vec![acp::StopReason::Cancelled],
1807 );
1808
1809 // Verify we can send a new message after cancellation
1810 verify_thread_recovery(&thread, &fake_model, cx).await;
1811}
1812
1813/// Helper to verify thread can recover after cancellation by sending a simple message.
1814async fn verify_thread_recovery(
1815 thread: &Entity<Thread>,
1816 fake_model: &FakeLanguageModel,
1817 cx: &mut TestAppContext,
1818) {
1819 let events = thread
1820 .update(cx, |thread, cx| {
1821 thread.send(
1822 UserMessageId::new(),
1823 ["Testing: reply with 'Hello' then stop."],
1824 cx,
1825 )
1826 })
1827 .unwrap();
1828 cx.run_until_parked();
1829 fake_model.send_last_completion_stream_text_chunk("Hello");
1830 fake_model
1831 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1832 fake_model.end_last_completion_stream();
1833
1834 let events = events.collect::<Vec<_>>().await;
1835 thread.update(cx, |thread, _cx| {
1836 let message = thread.last_message().unwrap();
1837 let agent_message = message.as_agent_message().unwrap();
1838 assert_eq!(
1839 agent_message.content,
1840 vec![AgentMessageContent::Text("Hello".to_string())]
1841 );
1842 });
1843 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1844}
1845
1846/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1847async fn wait_for_terminal_tool_started(
1848 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1849 cx: &mut TestAppContext,
1850) {
1851 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1852 for _ in 0..deadline {
1853 cx.run_until_parked();
1854
1855 while let Some(Some(event)) = events.next().now_or_never() {
1856 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1857 update,
1858 ))) = &event
1859 {
1860 if update.fields.content.as_ref().is_some_and(|content| {
1861 content
1862 .iter()
1863 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1864 }) {
1865 return;
1866 }
1867 }
1868 }
1869
1870 cx.background_executor
1871 .timer(Duration::from_millis(10))
1872 .await;
1873 }
1874 panic!("terminal tool did not start within the expected time");
1875}
1876
1877/// Collects events until a Stop event is received, driving the executor to completion.
1878async fn collect_events_until_stop(
1879 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1880 cx: &mut TestAppContext,
1881) -> Vec<Result<ThreadEvent>> {
1882 let mut collected = Vec::new();
1883 let deadline = cx.executor().num_cpus() * 200;
1884
1885 for _ in 0..deadline {
1886 cx.executor().advance_clock(Duration::from_millis(10));
1887 cx.run_until_parked();
1888
1889 while let Some(Some(event)) = events.next().now_or_never() {
1890 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1891 collected.push(event);
1892 if is_stop {
1893 return collected;
1894 }
1895 }
1896 }
1897 panic!(
1898 "did not receive Stop event within the expected time; collected {} events",
1899 collected.len()
1900 );
1901}
1902
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncates the thread at its only user message while a terminal tool call
    // is in flight, then checks that the fake terminal was killed, the thread
    // renders as empty markdown, and a new message can still be sent.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the id of the message so the thread can be truncated at it later.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    // The collected events themselves are not inspected here.
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1969
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // The environment's `handles()` is read after cancellation to inspect every
    // terminal handle it created.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // Drain the events that are immediately available without blocking.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        // Leaves only the inner drain loop; the check below exits the outer loop.
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2078
2079#[gpui::test]
2080async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2081 // Tests that clicking the stop button on the terminal card (as opposed to the main
2082 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2083 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2084 always_allow_tools(cx);
2085 let fake_model = model.as_fake();
2086
2087 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2088 let environment = Rc::new(FakeThreadEnvironment {
2089 handle: handle.clone(),
2090 });
2091
2092 let mut events = thread
2093 .update(cx, |thread, cx| {
2094 thread.add_tool(crate::TerminalTool::new(
2095 thread.project().clone(),
2096 environment,
2097 ));
2098 thread.send(UserMessageId::new(), ["run a command"], cx)
2099 })
2100 .unwrap();
2101
2102 cx.run_until_parked();
2103
2104 // Simulate the model calling the terminal tool
2105 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2106 LanguageModelToolUse {
2107 id: "terminal_tool_1".into(),
2108 name: "terminal".into(),
2109 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2110 input: json!({"command": "sleep 1000", "cd": "."}),
2111 is_input_complete: true,
2112 thought_signature: None,
2113 },
2114 ));
2115 fake_model.end_last_completion_stream();
2116
2117 // Wait for the terminal tool to start running
2118 wait_for_terminal_tool_started(&mut events, cx).await;
2119
2120 // Simulate user clicking stop on the terminal card itself.
2121 // This sets the flag and signals exit (simulating what the real UI would do).
2122 handle.set_stopped_by_user(true);
2123 handle.killed.store(true, Ordering::SeqCst);
2124 handle.signal_exit();
2125
2126 // Wait for the tool to complete
2127 cx.run_until_parked();
2128
2129 // The thread continues after tool completion - simulate the model ending its turn
2130 fake_model
2131 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2132 fake_model.end_last_completion_stream();
2133
2134 // Collect remaining events
2135 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2136
2137 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2138 assert_eq!(
2139 stop_events(remaining_events),
2140 vec![acp::StopReason::EndTurn],
2141 );
2142
2143 // Verify the tool result indicates user stopped
2144 thread.update(cx, |thread, _cx| {
2145 let message = thread.last_message().unwrap();
2146 let agent_message = message.as_agent_message().unwrap();
2147
2148 let tool_use = agent_message
2149 .content
2150 .iter()
2151 .find_map(|content| match content {
2152 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2153 _ => None,
2154 })
2155 .expect("expected tool use in agent message");
2156
2157 let tool_result = agent_message
2158 .tool_results
2159 .get(&tool_use.id)
2160 .expect("expected tool result");
2161
2162 let result_text = match &tool_result.content {
2163 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2164 _ => panic!("expected text content in tool result"),
2165 };
2166
2167 assert!(
2168 result_text.contains("The user stopped this command"),
2169 "expected tool result to indicate user stopped, got: {result_text}"
2170 );
2171 });
2172}
2173
2174#[gpui::test]
2175async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2176 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2177 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2178 always_allow_tools(cx);
2179 let fake_model = model.as_fake();
2180
2181 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2182 let environment = Rc::new(FakeThreadEnvironment {
2183 handle: handle.clone(),
2184 });
2185
2186 let mut events = thread
2187 .update(cx, |thread, cx| {
2188 thread.add_tool(crate::TerminalTool::new(
2189 thread.project().clone(),
2190 environment,
2191 ));
2192 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2193 })
2194 .unwrap();
2195
2196 cx.run_until_parked();
2197
2198 // Simulate the model calling the terminal tool with a short timeout
2199 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2200 LanguageModelToolUse {
2201 id: "terminal_tool_1".into(),
2202 name: "terminal".into(),
2203 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2204 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2205 is_input_complete: true,
2206 thought_signature: None,
2207 },
2208 ));
2209 fake_model.end_last_completion_stream();
2210
2211 // Wait for the terminal tool to start running
2212 wait_for_terminal_tool_started(&mut events, cx).await;
2213
2214 // Advance clock past the timeout
2215 cx.executor().advance_clock(Duration::from_millis(200));
2216 cx.run_until_parked();
2217
2218 // The thread continues after tool completion - simulate the model ending its turn
2219 fake_model
2220 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2221 fake_model.end_last_completion_stream();
2222
2223 // Collect remaining events
2224 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2225
2226 // Verify the terminal was killed due to timeout
2227 assert!(
2228 handle.was_killed(),
2229 "expected terminal handle to be killed on timeout"
2230 );
2231
2232 // Verify we got an EndTurn (the tool completed, just with timeout)
2233 assert_eq!(
2234 stop_events(remaining_events),
2235 vec![acp::StopReason::EndTurn],
2236 );
2237
2238 // Verify the tool result indicates timeout, not user stopped
2239 thread.update(cx, |thread, _cx| {
2240 let message = thread.last_message().unwrap();
2241 let agent_message = message.as_agent_message().unwrap();
2242
2243 let tool_use = agent_message
2244 .content
2245 .iter()
2246 .find_map(|content| match content {
2247 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2248 _ => None,
2249 })
2250 .expect("expected tool use in agent message");
2251
2252 let tool_result = agent_message
2253 .tool_results
2254 .get(&tool_use.id)
2255 .expect("expected tool result");
2256
2257 let result_text = match &tool_result.content {
2258 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2259 _ => panic!("expected text content in tool result"),
2260 };
2261
2262 assert!(
2263 result_text.contains("timed out"),
2264 "expected tool result to indicate timeout, got: {result_text}"
2265 );
2266 assert!(
2267 !result_text.contains("The user stopped"),
2268 "tool result should not mention user stopped when it timed out, got: {result_text}"
2269 );
2270 });
2271}
2272
2273#[gpui::test]
2274async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2275 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2276 let fake_model = model.as_fake();
2277
2278 let events_1 = thread
2279 .update(cx, |thread, cx| {
2280 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2281 })
2282 .unwrap();
2283 cx.run_until_parked();
2284 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2285 cx.run_until_parked();
2286
2287 let events_2 = thread
2288 .update(cx, |thread, cx| {
2289 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2290 })
2291 .unwrap();
2292 cx.run_until_parked();
2293 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2294 fake_model
2295 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2296 fake_model.end_last_completion_stream();
2297
2298 let events_1 = events_1.collect::<Vec<_>>().await;
2299 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2300 let events_2 = events_2.collect::<Vec<_>>().await;
2301 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2302}
2303
2304#[gpui::test]
2305async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2306 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2307 let fake_model = model.as_fake();
2308
2309 let events_1 = thread
2310 .update(cx, |thread, cx| {
2311 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2312 })
2313 .unwrap();
2314 cx.run_until_parked();
2315 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2316 fake_model
2317 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2318 fake_model.end_last_completion_stream();
2319 let events_1 = events_1.collect::<Vec<_>>().await;
2320
2321 let events_2 = thread
2322 .update(cx, |thread, cx| {
2323 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2324 })
2325 .unwrap();
2326 cx.run_until_parked();
2327 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2328 fake_model
2329 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2330 fake_model.end_last_completion_stream();
2331 let events_2 = events_2.collect::<Vec<_>>().await;
2332
2333 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2334 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2335}
2336
2337#[gpui::test]
2338async fn test_refusal(cx: &mut TestAppContext) {
2339 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2340 let fake_model = model.as_fake();
2341
2342 let events = thread
2343 .update(cx, |thread, cx| {
2344 thread.send(UserMessageId::new(), ["Hello"], cx)
2345 })
2346 .unwrap();
2347 cx.run_until_parked();
2348 thread.read_with(cx, |thread, _| {
2349 assert_eq!(
2350 thread.to_markdown(),
2351 indoc! {"
2352 ## User
2353
2354 Hello
2355 "}
2356 );
2357 });
2358
2359 fake_model.send_last_completion_stream_text_chunk("Hey!");
2360 cx.run_until_parked();
2361 thread.read_with(cx, |thread, _| {
2362 assert_eq!(
2363 thread.to_markdown(),
2364 indoc! {"
2365 ## User
2366
2367 Hello
2368
2369 ## Assistant
2370
2371 Hey!
2372 "}
2373 );
2374 });
2375
2376 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2377 fake_model
2378 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2379 let events = events.collect::<Vec<_>>().await;
2380 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2381 thread.read_with(cx, |thread, _| {
2382 assert_eq!(thread.to_markdown(), "");
2383 });
2384}
2385
2386#[gpui::test]
2387async fn test_truncate_first_message(cx: &mut TestAppContext) {
2388 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2389 let fake_model = model.as_fake();
2390
2391 let message_id = UserMessageId::new();
2392 thread
2393 .update(cx, |thread, cx| {
2394 thread.send(message_id.clone(), ["Hello"], cx)
2395 })
2396 .unwrap();
2397 cx.run_until_parked();
2398 thread.read_with(cx, |thread, _| {
2399 assert_eq!(
2400 thread.to_markdown(),
2401 indoc! {"
2402 ## User
2403
2404 Hello
2405 "}
2406 );
2407 assert_eq!(thread.latest_token_usage(), None);
2408 });
2409
2410 fake_model.send_last_completion_stream_text_chunk("Hey!");
2411 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2412 language_model::TokenUsage {
2413 input_tokens: 32_000,
2414 output_tokens: 16_000,
2415 cache_creation_input_tokens: 0,
2416 cache_read_input_tokens: 0,
2417 },
2418 ));
2419 cx.run_until_parked();
2420 thread.read_with(cx, |thread, _| {
2421 assert_eq!(
2422 thread.to_markdown(),
2423 indoc! {"
2424 ## User
2425
2426 Hello
2427
2428 ## Assistant
2429
2430 Hey!
2431 "}
2432 );
2433 assert_eq!(
2434 thread.latest_token_usage(),
2435 Some(acp_thread::TokenUsage {
2436 used_tokens: 32_000 + 16_000,
2437 max_tokens: 1_000_000,
2438 input_tokens: 32_000,
2439 output_tokens: 16_000,
2440 })
2441 );
2442 });
2443
2444 thread
2445 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2446 .unwrap();
2447 cx.run_until_parked();
2448 thread.read_with(cx, |thread, _| {
2449 assert_eq!(thread.to_markdown(), "");
2450 assert_eq!(thread.latest_token_usage(), None);
2451 });
2452
2453 // Ensure we can still send a new message after truncation.
2454 thread
2455 .update(cx, |thread, cx| {
2456 thread.send(UserMessageId::new(), ["Hi"], cx)
2457 })
2458 .unwrap();
2459 thread.update(cx, |thread, _cx| {
2460 assert_eq!(
2461 thread.to_markdown(),
2462 indoc! {"
2463 ## User
2464
2465 Hi
2466 "}
2467 );
2468 });
2469 cx.run_until_parked();
2470 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2471 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2472 language_model::TokenUsage {
2473 input_tokens: 40_000,
2474 output_tokens: 20_000,
2475 cache_creation_input_tokens: 0,
2476 cache_read_input_tokens: 0,
2477 },
2478 ));
2479 cx.run_until_parked();
2480 thread.read_with(cx, |thread, _| {
2481 assert_eq!(
2482 thread.to_markdown(),
2483 indoc! {"
2484 ## User
2485
2486 Hi
2487
2488 ## Assistant
2489
2490 Ahoy!
2491 "}
2492 );
2493
2494 assert_eq!(
2495 thread.latest_token_usage(),
2496 Some(acp_thread::TokenUsage {
2497 used_tokens: 40_000 + 20_000,
2498 max_tokens: 1_000_000,
2499 input_tokens: 40_000,
2500 output_tokens: 20_000,
2501 })
2502 );
2503 });
2504}
2505
2506#[gpui::test]
2507async fn test_truncate_second_message(cx: &mut TestAppContext) {
2508 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2509 let fake_model = model.as_fake();
2510
2511 thread
2512 .update(cx, |thread, cx| {
2513 thread.send(UserMessageId::new(), ["Message 1"], cx)
2514 })
2515 .unwrap();
2516 cx.run_until_parked();
2517 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2518 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2519 language_model::TokenUsage {
2520 input_tokens: 32_000,
2521 output_tokens: 16_000,
2522 cache_creation_input_tokens: 0,
2523 cache_read_input_tokens: 0,
2524 },
2525 ));
2526 fake_model.end_last_completion_stream();
2527 cx.run_until_parked();
2528
2529 let assert_first_message_state = |cx: &mut TestAppContext| {
2530 thread.clone().read_with(cx, |thread, _| {
2531 assert_eq!(
2532 thread.to_markdown(),
2533 indoc! {"
2534 ## User
2535
2536 Message 1
2537
2538 ## Assistant
2539
2540 Message 1 response
2541 "}
2542 );
2543
2544 assert_eq!(
2545 thread.latest_token_usage(),
2546 Some(acp_thread::TokenUsage {
2547 used_tokens: 32_000 + 16_000,
2548 max_tokens: 1_000_000,
2549 input_tokens: 32_000,
2550 output_tokens: 16_000,
2551 })
2552 );
2553 });
2554 };
2555
2556 assert_first_message_state(cx);
2557
2558 let second_message_id = UserMessageId::new();
2559 thread
2560 .update(cx, |thread, cx| {
2561 thread.send(second_message_id.clone(), ["Message 2"], cx)
2562 })
2563 .unwrap();
2564 cx.run_until_parked();
2565
2566 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2567 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2568 language_model::TokenUsage {
2569 input_tokens: 40_000,
2570 output_tokens: 20_000,
2571 cache_creation_input_tokens: 0,
2572 cache_read_input_tokens: 0,
2573 },
2574 ));
2575 fake_model.end_last_completion_stream();
2576 cx.run_until_parked();
2577
2578 thread.read_with(cx, |thread, _| {
2579 assert_eq!(
2580 thread.to_markdown(),
2581 indoc! {"
2582 ## User
2583
2584 Message 1
2585
2586 ## Assistant
2587
2588 Message 1 response
2589
2590 ## User
2591
2592 Message 2
2593
2594 ## Assistant
2595
2596 Message 2 response
2597 "}
2598 );
2599
2600 assert_eq!(
2601 thread.latest_token_usage(),
2602 Some(acp_thread::TokenUsage {
2603 used_tokens: 40_000 + 20_000,
2604 max_tokens: 1_000_000,
2605 input_tokens: 40_000,
2606 output_tokens: 20_000,
2607 })
2608 );
2609 });
2610
2611 thread
2612 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2613 .unwrap();
2614 cx.run_until_parked();
2615
2616 assert_first_message_state(cx);
2617}
2618
2619#[gpui::test]
2620async fn test_title_generation(cx: &mut TestAppContext) {
2621 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2622 let fake_model = model.as_fake();
2623
2624 let summary_model = Arc::new(FakeLanguageModel::default());
2625 thread.update(cx, |thread, cx| {
2626 thread.set_summarization_model(Some(summary_model.clone()), cx)
2627 });
2628
2629 let send = thread
2630 .update(cx, |thread, cx| {
2631 thread.send(UserMessageId::new(), ["Hello"], cx)
2632 })
2633 .unwrap();
2634 cx.run_until_parked();
2635
2636 fake_model.send_last_completion_stream_text_chunk("Hey!");
2637 fake_model.end_last_completion_stream();
2638 cx.run_until_parked();
2639 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2640
2641 // Ensure the summary model has been invoked to generate a title.
2642 summary_model.send_last_completion_stream_text_chunk("Hello ");
2643 summary_model.send_last_completion_stream_text_chunk("world\nG");
2644 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2645 summary_model.end_last_completion_stream();
2646 send.collect::<Vec<_>>().await;
2647 cx.run_until_parked();
2648 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2649
2650 // Send another message, ensuring no title is generated this time.
2651 let send = thread
2652 .update(cx, |thread, cx| {
2653 thread.send(UserMessageId::new(), ["Hello again"], cx)
2654 })
2655 .unwrap();
2656 cx.run_until_parked();
2657 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2658 fake_model.end_last_completion_stream();
2659 cx.run_until_parked();
2660 assert_eq!(summary_model.pending_completions(), Vec::new());
2661 send.collect::<Vec<_>>().await;
2662 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2663}
2664
2665#[gpui::test]
2666async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2667 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2668 let fake_model = model.as_fake();
2669
2670 let _events = thread
2671 .update(cx, |thread, cx| {
2672 thread.add_tool(ToolRequiringPermission);
2673 thread.add_tool(EchoTool);
2674 thread.send(UserMessageId::new(), ["Hey!"], cx)
2675 })
2676 .unwrap();
2677 cx.run_until_parked();
2678
2679 let permission_tool_use = LanguageModelToolUse {
2680 id: "tool_id_1".into(),
2681 name: ToolRequiringPermission::name().into(),
2682 raw_input: "{}".into(),
2683 input: json!({}),
2684 is_input_complete: true,
2685 thought_signature: None,
2686 };
2687 let echo_tool_use = LanguageModelToolUse {
2688 id: "tool_id_2".into(),
2689 name: EchoTool::name().into(),
2690 raw_input: json!({"text": "test"}).to_string(),
2691 input: json!({"text": "test"}),
2692 is_input_complete: true,
2693 thought_signature: None,
2694 };
2695 fake_model.send_last_completion_stream_text_chunk("Hi!");
2696 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2697 permission_tool_use,
2698 ));
2699 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2700 echo_tool_use.clone(),
2701 ));
2702 fake_model.end_last_completion_stream();
2703 cx.run_until_parked();
2704
2705 // Ensure pending tools are skipped when building a request.
2706 let request = thread
2707 .read_with(cx, |thread, cx| {
2708 thread.build_completion_request(CompletionIntent::EditFile, cx)
2709 })
2710 .unwrap();
2711 assert_eq!(
2712 request.messages[1..],
2713 vec![
2714 LanguageModelRequestMessage {
2715 role: Role::User,
2716 content: vec!["Hey!".into()],
2717 cache: true,
2718 reasoning_details: None,
2719 },
2720 LanguageModelRequestMessage {
2721 role: Role::Assistant,
2722 content: vec![
2723 MessageContent::Text("Hi!".into()),
2724 MessageContent::ToolUse(echo_tool_use.clone())
2725 ],
2726 cache: false,
2727 reasoning_details: None,
2728 },
2729 LanguageModelRequestMessage {
2730 role: Role::User,
2731 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2732 tool_use_id: echo_tool_use.id.clone(),
2733 tool_name: echo_tool_use.name,
2734 is_error: false,
2735 content: "test".into(),
2736 output: Some("test".into())
2737 })],
2738 cache: false,
2739 reasoning_details: None,
2740 },
2741 ],
2742 );
2743}
2744
2745#[gpui::test]
2746async fn test_agent_connection(cx: &mut TestAppContext) {
2747 cx.update(settings::init);
2748 let templates = Templates::new();
2749
2750 // Initialize language model system with test provider
2751 cx.update(|cx| {
2752 gpui_tokio::init(cx);
2753
2754 let http_client = FakeHttpClient::with_404_response();
2755 let clock = Arc::new(clock::FakeSystemClock::new());
2756 let client = Client::new(clock, http_client, cx);
2757 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2758 language_model::init(client.clone(), cx);
2759 language_models::init(user_store, client.clone(), cx);
2760 LanguageModelRegistry::test(cx);
2761 });
2762 cx.executor().forbid_parking();
2763
2764 // Create a project for new_thread
2765 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2766 fake_fs.insert_tree(path!("/test"), json!({})).await;
2767 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2768 let cwd = Path::new("/test");
2769 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2770
2771 // Create agent and connection
2772 let agent = NativeAgent::new(
2773 project.clone(),
2774 thread_store,
2775 templates.clone(),
2776 None,
2777 fake_fs.clone(),
2778 &mut cx.to_async(),
2779 )
2780 .await
2781 .unwrap();
2782 let connection = NativeAgentConnection(agent.clone());
2783
2784 // Create a thread using new_thread
2785 let connection_rc = Rc::new(connection.clone());
2786 let acp_thread = cx
2787 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2788 .await
2789 .expect("new_thread should succeed");
2790
2791 // Get the session_id from the AcpThread
2792 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2793
2794 // Test model_selector returns Some
2795 let selector_opt = connection.model_selector(&session_id);
2796 assert!(
2797 selector_opt.is_some(),
2798 "agent should always support ModelSelector"
2799 );
2800 let selector = selector_opt.unwrap();
2801
2802 // Test list_models
2803 let listed_models = cx
2804 .update(|cx| selector.list_models(cx))
2805 .await
2806 .expect("list_models should succeed");
2807 let AgentModelList::Grouped(listed_models) = listed_models else {
2808 panic!("Unexpected model list type");
2809 };
2810 assert!(!listed_models.is_empty(), "should have at least one model");
2811 assert_eq!(
2812 listed_models[&AgentModelGroupName("Fake".into())][0]
2813 .id
2814 .0
2815 .as_ref(),
2816 "fake/fake"
2817 );
2818
2819 // Test selected_model returns the default
2820 let model = cx
2821 .update(|cx| selector.selected_model(cx))
2822 .await
2823 .expect("selected_model should succeed");
2824 let model = cx
2825 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2826 .unwrap();
2827 let model = model.as_fake();
2828 assert_eq!(model.id().0, "fake", "should return default model");
2829
2830 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2831 cx.run_until_parked();
2832 model.send_last_completion_stream_text_chunk("def");
2833 cx.run_until_parked();
2834 acp_thread.read_with(cx, |thread, cx| {
2835 assert_eq!(
2836 thread.to_markdown(cx),
2837 indoc! {"
2838 ## User
2839
2840 abc
2841
2842 ## Assistant
2843
2844 def
2845
2846 "}
2847 )
2848 });
2849
2850 // Test cancel
2851 cx.update(|cx| connection.cancel(&session_id, cx));
2852 request.await.expect("prompt should fail gracefully");
2853
2854 // Ensure that dropping the ACP thread causes the native thread to be
2855 // dropped as well.
2856 cx.update(|_| drop(acp_thread));
2857 let result = cx
2858 .update(|cx| {
2859 connection.prompt(
2860 Some(acp_thread::UserMessageId::new()),
2861 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2862 cx,
2863 )
2864 })
2865 .await;
2866 assert_eq!(
2867 result.as_ref().unwrap_err().to_string(),
2868 "Session not found",
2869 "unexpected result: {:?}",
2870 result
2871 );
2872}
2873
2874#[gpui::test]
2875async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2876 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2877 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2878 let fake_model = model.as_fake();
2879
2880 let mut events = thread
2881 .update(cx, |thread, cx| {
2882 thread.send(UserMessageId::new(), ["Think"], cx)
2883 })
2884 .unwrap();
2885 cx.run_until_parked();
2886
2887 // Simulate streaming partial input.
2888 let input = json!({});
2889 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2890 LanguageModelToolUse {
2891 id: "1".into(),
2892 name: ThinkingTool::name().into(),
2893 raw_input: input.to_string(),
2894 input,
2895 is_input_complete: false,
2896 thought_signature: None,
2897 },
2898 ));
2899
2900 // Input streaming completed
2901 let input = json!({ "content": "Thinking hard!" });
2902 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2903 LanguageModelToolUse {
2904 id: "1".into(),
2905 name: "thinking".into(),
2906 raw_input: input.to_string(),
2907 input,
2908 is_input_complete: true,
2909 thought_signature: None,
2910 },
2911 ));
2912 fake_model.end_last_completion_stream();
2913 cx.run_until_parked();
2914
2915 let tool_call = expect_tool_call(&mut events).await;
2916 assert_eq!(
2917 tool_call,
2918 acp::ToolCall::new("1", "Thinking")
2919 .kind(acp::ToolKind::Think)
2920 .raw_input(json!({}))
2921 .meta(acp::Meta::from_iter([(
2922 "tool_name".into(),
2923 "thinking".into()
2924 )]))
2925 );
2926 let update = expect_tool_call_update_fields(&mut events).await;
2927 assert_eq!(
2928 update,
2929 acp::ToolCallUpdate::new(
2930 "1",
2931 acp::ToolCallUpdateFields::new()
2932 .title("Thinking")
2933 .kind(acp::ToolKind::Think)
2934 .raw_input(json!({ "content": "Thinking hard!"}))
2935 )
2936 );
2937 let update = expect_tool_call_update_fields(&mut events).await;
2938 assert_eq!(
2939 update,
2940 acp::ToolCallUpdate::new(
2941 "1",
2942 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2943 )
2944 );
2945 let update = expect_tool_call_update_fields(&mut events).await;
2946 assert_eq!(
2947 update,
2948 acp::ToolCallUpdate::new(
2949 "1",
2950 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2951 )
2952 );
2953 let update = expect_tool_call_update_fields(&mut events).await;
2954 assert_eq!(
2955 update,
2956 acp::ToolCallUpdate::new(
2957 "1",
2958 acp::ToolCallUpdateFields::new()
2959 .status(acp::ToolCallStatus::Completed)
2960 .raw_output("Finished thinking.")
2961 )
2962 );
2963}
2964
2965#[gpui::test]
2966async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2967 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2968 let fake_model = model.as_fake();
2969
2970 let mut events = thread
2971 .update(cx, |thread, cx| {
2972 thread.send(UserMessageId::new(), ["Hello!"], cx)
2973 })
2974 .unwrap();
2975 cx.run_until_parked();
2976
2977 fake_model.send_last_completion_stream_text_chunk("Hey!");
2978 fake_model.end_last_completion_stream();
2979
2980 let mut retry_events = Vec::new();
2981 while let Some(Ok(event)) = events.next().await {
2982 match event {
2983 ThreadEvent::Retry(retry_status) => {
2984 retry_events.push(retry_status);
2985 }
2986 ThreadEvent::Stop(..) => break,
2987 _ => {}
2988 }
2989 }
2990
2991 assert_eq!(retry_events.len(), 0);
2992 thread.read_with(cx, |thread, _cx| {
2993 assert_eq!(
2994 thread.to_markdown(),
2995 indoc! {"
2996 ## User
2997
2998 Hello!
2999
3000 ## Assistant
3001
3002 Hey!
3003 "}
3004 )
3005 });
3006}
3007
3008#[gpui::test]
3009async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3010 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3011 let fake_model = model.as_fake();
3012
3013 let mut events = thread
3014 .update(cx, |thread, cx| {
3015 thread.send(UserMessageId::new(), ["Hello!"], cx)
3016 })
3017 .unwrap();
3018 cx.run_until_parked();
3019
3020 fake_model.send_last_completion_stream_text_chunk("Hey,");
3021 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3022 provider: LanguageModelProviderName::new("Anthropic"),
3023 retry_after: Some(Duration::from_secs(3)),
3024 });
3025 fake_model.end_last_completion_stream();
3026
3027 cx.executor().advance_clock(Duration::from_secs(3));
3028 cx.run_until_parked();
3029
3030 fake_model.send_last_completion_stream_text_chunk("there!");
3031 fake_model.end_last_completion_stream();
3032 cx.run_until_parked();
3033
3034 let mut retry_events = Vec::new();
3035 while let Some(Ok(event)) = events.next().await {
3036 match event {
3037 ThreadEvent::Retry(retry_status) => {
3038 retry_events.push(retry_status);
3039 }
3040 ThreadEvent::Stop(..) => break,
3041 _ => {}
3042 }
3043 }
3044
3045 assert_eq!(retry_events.len(), 1);
3046 assert!(matches!(
3047 retry_events[0],
3048 acp_thread::RetryStatus { attempt: 1, .. }
3049 ));
3050 thread.read_with(cx, |thread, _cx| {
3051 assert_eq!(
3052 thread.to_markdown(),
3053 indoc! {"
3054 ## User
3055
3056 Hello!
3057
3058 ## Assistant
3059
3060 Hey,
3061
3062 [resume]
3063
3064 ## Assistant
3065
3066 there!
3067 "}
3068 )
3069 });
3070}
3071
3072#[gpui::test]
3073async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3074 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3075 let fake_model = model.as_fake();
3076
3077 let events = thread
3078 .update(cx, |thread, cx| {
3079 thread.add_tool(EchoTool);
3080 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3081 })
3082 .unwrap();
3083 cx.run_until_parked();
3084
3085 let tool_use_1 = LanguageModelToolUse {
3086 id: "tool_1".into(),
3087 name: EchoTool::name().into(),
3088 raw_input: json!({"text": "test"}).to_string(),
3089 input: json!({"text": "test"}),
3090 is_input_complete: true,
3091 thought_signature: None,
3092 };
3093 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3094 tool_use_1.clone(),
3095 ));
3096 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3097 provider: LanguageModelProviderName::new("Anthropic"),
3098 retry_after: Some(Duration::from_secs(3)),
3099 });
3100 fake_model.end_last_completion_stream();
3101
3102 cx.executor().advance_clock(Duration::from_secs(3));
3103 let completion = fake_model.pending_completions().pop().unwrap();
3104 assert_eq!(
3105 completion.messages[1..],
3106 vec![
3107 LanguageModelRequestMessage {
3108 role: Role::User,
3109 content: vec!["Call the echo tool!".into()],
3110 cache: false,
3111 reasoning_details: None,
3112 },
3113 LanguageModelRequestMessage {
3114 role: Role::Assistant,
3115 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3116 cache: false,
3117 reasoning_details: None,
3118 },
3119 LanguageModelRequestMessage {
3120 role: Role::User,
3121 content: vec![language_model::MessageContent::ToolResult(
3122 LanguageModelToolResult {
3123 tool_use_id: tool_use_1.id.clone(),
3124 tool_name: tool_use_1.name.clone(),
3125 is_error: false,
3126 content: "test".into(),
3127 output: Some("test".into())
3128 }
3129 )],
3130 cache: true,
3131 reasoning_details: None,
3132 },
3133 ]
3134 );
3135
3136 fake_model.send_last_completion_stream_text_chunk("Done");
3137 fake_model.end_last_completion_stream();
3138 cx.run_until_parked();
3139 events.collect::<Vec<_>>().await;
3140 thread.read_with(cx, |thread, _cx| {
3141 assert_eq!(
3142 thread.last_message(),
3143 Some(Message::Agent(AgentMessage {
3144 content: vec![AgentMessageContent::Text("Done".into())],
3145 tool_results: IndexMap::default(),
3146 reasoning_details: None,
3147 }))
3148 );
3149 })
3150}
3151
3152#[gpui::test]
3153async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3154 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3155 let fake_model = model.as_fake();
3156
3157 let mut events = thread
3158 .update(cx, |thread, cx| {
3159 thread.send(UserMessageId::new(), ["Hello!"], cx)
3160 })
3161 .unwrap();
3162 cx.run_until_parked();
3163
3164 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3165 fake_model.send_last_completion_stream_error(
3166 LanguageModelCompletionError::ServerOverloaded {
3167 provider: LanguageModelProviderName::new("Anthropic"),
3168 retry_after: Some(Duration::from_secs(3)),
3169 },
3170 );
3171 fake_model.end_last_completion_stream();
3172 cx.executor().advance_clock(Duration::from_secs(3));
3173 cx.run_until_parked();
3174 }
3175
3176 let mut errors = Vec::new();
3177 let mut retry_events = Vec::new();
3178 while let Some(event) = events.next().await {
3179 match event {
3180 Ok(ThreadEvent::Retry(retry_status)) => {
3181 retry_events.push(retry_status);
3182 }
3183 Ok(ThreadEvent::Stop(..)) => break,
3184 Err(error) => errors.push(error),
3185 _ => {}
3186 }
3187 }
3188
3189 assert_eq!(
3190 retry_events.len(),
3191 crate::thread::MAX_RETRY_ATTEMPTS as usize
3192 );
3193 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3194 assert_eq!(retry_events[i].attempt, i + 1);
3195 }
3196 assert_eq!(errors.len(), 1);
3197 let error = errors[0]
3198 .downcast_ref::<LanguageModelCompletionError>()
3199 .unwrap();
3200 assert!(matches!(
3201 error,
3202 LanguageModelCompletionError::ServerOverloaded { .. }
3203 ));
3204}
3205
3206/// Filters out the stop events for asserting against in tests
3207fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3208 result_events
3209 .into_iter()
3210 .filter_map(|event| match event.unwrap() {
3211 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3212 _ => None,
3213 })
3214 .collect()
3215}
3216
/// Everything `setup` creates for a test, returned together so each test can
/// destructure just the handles it needs.
struct ThreadTest {
    /// The language model the thread was constructed with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity that was passed to the thread.
    project_context: Entity<ProjectContext>,
    /// The context server store of the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system holding the settings file and the test project tree.
    fs: Arc<FakeFs>,
}
3224
3225enum TestModel {
3226 Sonnet4,
3227 Fake,
3228}
3229
3230impl TestModel {
3231 fn id(&self) -> LanguageModelId {
3232 match self {
3233 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3234 TestModel::Fake => unreachable!(),
3235 }
3236 }
3237}
3238
3239async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3240 cx.executor().allow_parking();
3241
3242 let fs = FakeFs::new(cx.background_executor.clone());
3243 fs.create_dir(paths::settings_file().parent().unwrap())
3244 .await
3245 .unwrap();
3246 fs.insert_file(
3247 paths::settings_file(),
3248 json!({
3249 "agent": {
3250 "default_profile": "test-profile",
3251 "profiles": {
3252 "test-profile": {
3253 "name": "Test Profile",
3254 "tools": {
3255 EchoTool::name(): true,
3256 DelayTool::name(): true,
3257 WordListTool::name(): true,
3258 ToolRequiringPermission::name(): true,
3259 InfiniteTool::name(): true,
3260 CancellationAwareTool::name(): true,
3261 ThinkingTool::name(): true,
3262 "terminal": true,
3263 }
3264 }
3265 }
3266 }
3267 })
3268 .to_string()
3269 .into_bytes(),
3270 )
3271 .await;
3272
3273 cx.update(|cx| {
3274 settings::init(cx);
3275
3276 match model {
3277 TestModel::Fake => {}
3278 TestModel::Sonnet4 => {
3279 gpui_tokio::init(cx);
3280 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3281 cx.set_http_client(Arc::new(http_client));
3282 let client = Client::production(cx);
3283 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3284 language_model::init(client.clone(), cx);
3285 language_models::init(user_store, client.clone(), cx);
3286 }
3287 };
3288
3289 watch_settings(fs.clone(), cx);
3290 });
3291
3292 let templates = Templates::new();
3293
3294 fs.insert_tree(path!("/test"), json!({})).await;
3295 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3296
3297 let model = cx
3298 .update(|cx| {
3299 if let TestModel::Fake = model {
3300 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3301 } else {
3302 let model_id = model.id();
3303 let models = LanguageModelRegistry::read_global(cx);
3304 let model = models
3305 .available_models(cx)
3306 .find(|model| model.id() == model_id)
3307 .unwrap();
3308
3309 let provider = models.provider(&model.provider_id()).unwrap();
3310 let authenticated = provider.authenticate(cx);
3311
3312 cx.spawn(async move |_cx| {
3313 authenticated.await.unwrap();
3314 model
3315 })
3316 }
3317 })
3318 .await;
3319
3320 let project_context = cx.new(|_cx| ProjectContext::default());
3321 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3322 let context_server_registry =
3323 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3324 let thread = cx.new(|cx| {
3325 Thread::new(
3326 project,
3327 project_context.clone(),
3328 context_server_registry,
3329 templates,
3330 Some(model.clone()),
3331 cx,
3332 )
3333 });
3334 ThreadTest {
3335 model,
3336 thread,
3337 project_context,
3338 context_server_store,
3339 fs,
3340 }
3341}
3342
/// Installs `env_logger` for the test binary, but only when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        // `env_logger::init` panics when a global logger is already installed. Use
        // the fallible variant and ignore the error so this constructor never panics
        // just because some other code installed a logger first.
        env_logger::try_init().ok();
    }
}
3350
3351fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3352 let fs = fs.clone();
3353 cx.spawn({
3354 async move |cx| {
3355 let mut new_settings_content_rx = settings::watch_config_file(
3356 cx.background_executor(),
3357 fs,
3358 paths::settings_file().clone(),
3359 );
3360
3361 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3362 cx.update(|cx| {
3363 SettingsStore::update_global(cx, |settings, cx| {
3364 settings.set_user_settings(&new_settings_content, cx)
3365 })
3366 })
3367 .ok();
3368 }
3369 }
3370 })
3371 .detach();
3372}
3373
3374fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3375 completion
3376 .tools
3377 .iter()
3378 .map(|tool| tool.name.clone())
3379 .collect()
3380}
3381
3382fn setup_context_server(
3383 name: &'static str,
3384 tools: Vec<context_server::types::Tool>,
3385 context_server_store: &Entity<ContextServerStore>,
3386 cx: &mut TestAppContext,
3387) -> mpsc::UnboundedReceiver<(
3388 context_server::types::CallToolParams,
3389 oneshot::Sender<context_server::types::CallToolResponse>,
3390)> {
3391 cx.update(|cx| {
3392 let mut settings = ProjectSettings::get_global(cx).clone();
3393 settings.context_servers.insert(
3394 name.into(),
3395 project::project_settings::ContextServerSettings::Stdio {
3396 enabled: true,
3397 remote: false,
3398 command: ContextServerCommand {
3399 path: "somebinary".into(),
3400 args: Vec::new(),
3401 env: None,
3402 timeout: None,
3403 },
3404 },
3405 );
3406 ProjectSettings::override_global(settings, cx);
3407 });
3408
3409 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3410 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3411 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3412 context_server::types::InitializeResponse {
3413 protocol_version: context_server::types::ProtocolVersion(
3414 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3415 ),
3416 server_info: context_server::types::Implementation {
3417 name: name.into(),
3418 version: "1.0.0".to_string(),
3419 },
3420 capabilities: context_server::types::ServerCapabilities {
3421 tools: Some(context_server::types::ToolsCapabilities {
3422 list_changed: Some(true),
3423 }),
3424 ..Default::default()
3425 },
3426 meta: None,
3427 }
3428 })
3429 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3430 let tools = tools.clone();
3431 async move {
3432 context_server::types::ListToolsResponse {
3433 tools,
3434 next_cursor: None,
3435 meta: None,
3436 }
3437 }
3438 })
3439 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3440 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3441 async move {
3442 let (response_tx, response_rx) = oneshot::channel();
3443 mcp_tool_calls_tx
3444 .unbounded_send((params, response_tx))
3445 .unwrap();
3446 response_rx.await.unwrap()
3447 }
3448 });
3449 context_server_store.update(cx, |store, cx| {
3450 store.start_server(
3451 Arc::new(ContextServer::new(
3452 ContextServerId(name.into()),
3453 Arc::new(fake_transport),
3454 )),
3455 cx,
3456 );
3457 });
3458 cx.run_until_parked();
3459 mcp_tool_calls_rx
3460}
3461
3462#[gpui::test]
3463async fn test_tokens_before_message(cx: &mut TestAppContext) {
3464 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3465 let fake_model = model.as_fake();
3466
3467 // First message
3468 let message_1_id = UserMessageId::new();
3469 thread
3470 .update(cx, |thread, cx| {
3471 thread.send(message_1_id.clone(), ["First message"], cx)
3472 })
3473 .unwrap();
3474 cx.run_until_parked();
3475
3476 // Before any response, tokens_before_message should return None for first message
3477 thread.read_with(cx, |thread, _| {
3478 assert_eq!(
3479 thread.tokens_before_message(&message_1_id),
3480 None,
3481 "First message should have no tokens before it"
3482 );
3483 });
3484
3485 // Complete first message with usage
3486 fake_model.send_last_completion_stream_text_chunk("Response 1");
3487 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3488 language_model::TokenUsage {
3489 input_tokens: 100,
3490 output_tokens: 50,
3491 cache_creation_input_tokens: 0,
3492 cache_read_input_tokens: 0,
3493 },
3494 ));
3495 fake_model.end_last_completion_stream();
3496 cx.run_until_parked();
3497
3498 // First message still has no tokens before it
3499 thread.read_with(cx, |thread, _| {
3500 assert_eq!(
3501 thread.tokens_before_message(&message_1_id),
3502 None,
3503 "First message should still have no tokens before it after response"
3504 );
3505 });
3506
3507 // Second message
3508 let message_2_id = UserMessageId::new();
3509 thread
3510 .update(cx, |thread, cx| {
3511 thread.send(message_2_id.clone(), ["Second message"], cx)
3512 })
3513 .unwrap();
3514 cx.run_until_parked();
3515
3516 // Second message should have first message's input tokens before it
3517 thread.read_with(cx, |thread, _| {
3518 assert_eq!(
3519 thread.tokens_before_message(&message_2_id),
3520 Some(100),
3521 "Second message should have 100 tokens before it (from first request)"
3522 );
3523 });
3524
3525 // Complete second message
3526 fake_model.send_last_completion_stream_text_chunk("Response 2");
3527 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3528 language_model::TokenUsage {
3529 input_tokens: 250, // Total for this request (includes previous context)
3530 output_tokens: 75,
3531 cache_creation_input_tokens: 0,
3532 cache_read_input_tokens: 0,
3533 },
3534 ));
3535 fake_model.end_last_completion_stream();
3536 cx.run_until_parked();
3537
3538 // Third message
3539 let message_3_id = UserMessageId::new();
3540 thread
3541 .update(cx, |thread, cx| {
3542 thread.send(message_3_id.clone(), ["Third message"], cx)
3543 })
3544 .unwrap();
3545 cx.run_until_parked();
3546
3547 // Third message should have second message's input tokens (250) before it
3548 thread.read_with(cx, |thread, _| {
3549 assert_eq!(
3550 thread.tokens_before_message(&message_3_id),
3551 Some(250),
3552 "Third message should have 250 tokens before it (from second request)"
3553 );
3554 // Second message should still have 100
3555 assert_eq!(
3556 thread.tokens_before_message(&message_2_id),
3557 Some(100),
3558 "Second message should still have 100 tokens before it"
3559 );
3560 // First message still has none
3561 assert_eq!(
3562 thread.tokens_before_message(&message_1_id),
3563 None,
3564 "First message should still have no tokens before it"
3565 );
3566 });
3567}
3568
3569#[gpui::test]
3570async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3571 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3572 let fake_model = model.as_fake();
3573
3574 // Set up three messages with responses
3575 let message_1_id = UserMessageId::new();
3576 thread
3577 .update(cx, |thread, cx| {
3578 thread.send(message_1_id.clone(), ["Message 1"], cx)
3579 })
3580 .unwrap();
3581 cx.run_until_parked();
3582 fake_model.send_last_completion_stream_text_chunk("Response 1");
3583 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3584 language_model::TokenUsage {
3585 input_tokens: 100,
3586 output_tokens: 50,
3587 cache_creation_input_tokens: 0,
3588 cache_read_input_tokens: 0,
3589 },
3590 ));
3591 fake_model.end_last_completion_stream();
3592 cx.run_until_parked();
3593
3594 let message_2_id = UserMessageId::new();
3595 thread
3596 .update(cx, |thread, cx| {
3597 thread.send(message_2_id.clone(), ["Message 2"], cx)
3598 })
3599 .unwrap();
3600 cx.run_until_parked();
3601 fake_model.send_last_completion_stream_text_chunk("Response 2");
3602 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3603 language_model::TokenUsage {
3604 input_tokens: 250,
3605 output_tokens: 75,
3606 cache_creation_input_tokens: 0,
3607 cache_read_input_tokens: 0,
3608 },
3609 ));
3610 fake_model.end_last_completion_stream();
3611 cx.run_until_parked();
3612
3613 // Verify initial state
3614 thread.read_with(cx, |thread, _| {
3615 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3616 });
3617
3618 // Truncate at message 2 (removes message 2 and everything after)
3619 thread
3620 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3621 .unwrap();
3622 cx.run_until_parked();
3623
3624 // After truncation, message_2_id no longer exists, so lookup should return None
3625 thread.read_with(cx, |thread, _| {
3626 assert_eq!(
3627 thread.tokens_before_message(&message_2_id),
3628 None,
3629 "After truncation, message 2 no longer exists"
3630 );
3631 // Message 1 still exists but has no tokens before it
3632 assert_eq!(
3633 thread.tokens_before_message(&message_1_id),
3634 None,
3635 "First message still has no tokens before it"
3636 );
3637 });
3638}
3639
3640#[gpui::test]
3641async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3642 init_test(cx);
3643
3644 let fs = FakeFs::new(cx.executor());
3645 fs.insert_tree("/root", json!({})).await;
3646 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3647
3648 // Test 1: Deny rule blocks command
3649 {
3650 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3651 let environment = Rc::new(FakeThreadEnvironment {
3652 handle: handle.clone(),
3653 });
3654
3655 cx.update(|cx| {
3656 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3657 settings.tool_permissions.tools.insert(
3658 "terminal".into(),
3659 agent_settings::ToolRules {
3660 default_mode: settings::ToolPermissionMode::Confirm,
3661 always_allow: vec![],
3662 always_deny: vec![
3663 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3664 ],
3665 always_confirm: vec![],
3666 invalid_patterns: vec![],
3667 },
3668 );
3669 agent_settings::AgentSettings::override_global(settings, cx);
3670 });
3671
3672 #[allow(clippy::arc_with_non_send_sync)]
3673 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3674 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3675
3676 let task = cx.update(|cx| {
3677 tool.run(
3678 crate::TerminalToolInput {
3679 command: "rm -rf /".to_string(),
3680 cd: ".".to_string(),
3681 timeout_ms: None,
3682 },
3683 event_stream,
3684 cx,
3685 )
3686 });
3687
3688 let result = task.await;
3689 assert!(
3690 result.is_err(),
3691 "expected command to be blocked by deny rule"
3692 );
3693 assert!(
3694 result.unwrap_err().to_string().contains("blocked"),
3695 "error should mention the command was blocked"
3696 );
3697 }
3698
3699 // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3700 {
3701 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3702 let environment = Rc::new(FakeThreadEnvironment {
3703 handle: handle.clone(),
3704 });
3705
3706 cx.update(|cx| {
3707 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3708 settings.always_allow_tool_actions = false;
3709 settings.tool_permissions.tools.insert(
3710 "terminal".into(),
3711 agent_settings::ToolRules {
3712 default_mode: settings::ToolPermissionMode::Deny,
3713 always_allow: vec![
3714 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3715 ],
3716 always_deny: vec![],
3717 always_confirm: vec![],
3718 invalid_patterns: vec![],
3719 },
3720 );
3721 agent_settings::AgentSettings::override_global(settings, cx);
3722 });
3723
3724 #[allow(clippy::arc_with_non_send_sync)]
3725 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3726 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3727
3728 let task = cx.update(|cx| {
3729 tool.run(
3730 crate::TerminalToolInput {
3731 command: "echo hello".to_string(),
3732 cd: ".".to_string(),
3733 timeout_ms: None,
3734 },
3735 event_stream,
3736 cx,
3737 )
3738 });
3739
3740 let update = rx.expect_update_fields().await;
3741 assert!(
3742 update.content.iter().any(|blocks| {
3743 blocks
3744 .iter()
3745 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3746 }),
3747 "expected terminal content (allow rule should skip confirmation and override default deny)"
3748 );
3749
3750 let result = task.await;
3751 assert!(
3752 result.is_ok(),
3753 "expected command to succeed without confirmation"
3754 );
3755 }
3756
3757 // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
3758 {
3759 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3760 let environment = Rc::new(FakeThreadEnvironment {
3761 handle: handle.clone(),
3762 });
3763
3764 cx.update(|cx| {
3765 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3766 settings.always_allow_tool_actions = true;
3767 settings.tool_permissions.tools.insert(
3768 "terminal".into(),
3769 agent_settings::ToolRules {
3770 default_mode: settings::ToolPermissionMode::Allow,
3771 always_allow: vec![],
3772 always_deny: vec![],
3773 always_confirm: vec![
3774 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3775 ],
3776 invalid_patterns: vec![],
3777 },
3778 );
3779 agent_settings::AgentSettings::override_global(settings, cx);
3780 });
3781
3782 #[allow(clippy::arc_with_non_send_sync)]
3783 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3784 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3785
3786 let task = cx.update(|cx| {
3787 tool.run(
3788 crate::TerminalToolInput {
3789 command: "sudo rm file".to_string(),
3790 cd: ".".to_string(),
3791 timeout_ms: None,
3792 },
3793 event_stream,
3794 cx,
3795 )
3796 });
3797
3798 // With always_allow_tool_actions=true, confirm patterns are overridden
3799 task.await
3800 .expect("command should be allowed with always_allow_tool_actions=true");
3801 }
3802
3803 // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
3804 {
3805 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3806 let environment = Rc::new(FakeThreadEnvironment {
3807 handle: handle.clone(),
3808 });
3809
3810 cx.update(|cx| {
3811 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3812 settings.always_allow_tool_actions = true;
3813 settings.tool_permissions.tools.insert(
3814 "terminal".into(),
3815 agent_settings::ToolRules {
3816 default_mode: settings::ToolPermissionMode::Deny,
3817 always_allow: vec![],
3818 always_deny: vec![],
3819 always_confirm: vec![],
3820 invalid_patterns: vec![],
3821 },
3822 );
3823 agent_settings::AgentSettings::override_global(settings, cx);
3824 });
3825
3826 #[allow(clippy::arc_with_non_send_sync)]
3827 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3828 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3829
3830 let task = cx.update(|cx| {
3831 tool.run(
3832 crate::TerminalToolInput {
3833 command: "echo hello".to_string(),
3834 cd: ".".to_string(),
3835 timeout_ms: None,
3836 },
3837 event_stream,
3838 cx,
3839 )
3840 });
3841
3842 // With always_allow_tool_actions=true, even default_mode: Deny is overridden
3843 task.await
3844 .expect("command should be allowed with always_allow_tool_actions=true");
3845 }
3846}
3847
3848#[gpui::test]
3849async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3850 init_test(cx);
3851
3852 cx.update(|cx| {
3853 cx.update_flags(true, vec!["subagents".to_string()]);
3854 });
3855
3856 let fs = FakeFs::new(cx.executor());
3857 fs.insert_tree(path!("/test"), json!({})).await;
3858 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3859 let project_context = cx.new(|_cx| ProjectContext::default());
3860 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3861 let context_server_registry =
3862 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3863 let model = Arc::new(FakeLanguageModel::default());
3864
3865 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3866 let environment = Rc::new(FakeThreadEnvironment { handle });
3867
3868 let thread = cx.new(|cx| {
3869 let mut thread = Thread::new(
3870 project.clone(),
3871 project_context,
3872 context_server_registry,
3873 Templates::new(),
3874 Some(model),
3875 cx,
3876 );
3877 thread.add_default_tools(environment, cx);
3878 thread
3879 });
3880
3881 thread.read_with(cx, |thread, _| {
3882 assert!(
3883 thread.has_registered_tool("subagent"),
3884 "subagent tool should be present when feature flag is enabled"
3885 );
3886 });
3887}
3888
3889#[gpui::test]
3890async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3891 init_test(cx);
3892
3893 cx.update(|cx| {
3894 cx.update_flags(true, vec!["subagents".to_string()]);
3895 });
3896
3897 let fs = FakeFs::new(cx.executor());
3898 fs.insert_tree(path!("/test"), json!({})).await;
3899 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3900 let project_context = cx.new(|_cx| ProjectContext::default());
3901 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3902 let context_server_registry =
3903 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3904 let model = Arc::new(FakeLanguageModel::default());
3905
3906 let subagent_context = SubagentContext {
3907 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3908 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3909 depth: 1,
3910 summary_prompt: "Summarize".to_string(),
3911 context_low_prompt: "Context low".to_string(),
3912 };
3913
3914 let subagent = cx.new(|cx| {
3915 Thread::new_subagent(
3916 project.clone(),
3917 project_context,
3918 context_server_registry,
3919 Templates::new(),
3920 model.clone(),
3921 subagent_context,
3922 std::collections::BTreeMap::new(),
3923 cx,
3924 )
3925 });
3926
3927 subagent.read_with(cx, |thread, _| {
3928 assert!(thread.is_subagent());
3929 assert_eq!(thread.depth(), 1);
3930 assert!(thread.model().is_some());
3931 });
3932}
3933
3934#[gpui::test]
3935async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
3936 init_test(cx);
3937
3938 cx.update(|cx| {
3939 cx.update_flags(true, vec!["subagents".to_string()]);
3940 });
3941
3942 let fs = FakeFs::new(cx.executor());
3943 fs.insert_tree(path!("/test"), json!({})).await;
3944 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3945 let project_context = cx.new(|_cx| ProjectContext::default());
3946 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3947 let context_server_registry =
3948 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3949 let model = Arc::new(FakeLanguageModel::default());
3950
3951 let subagent_context = SubagentContext {
3952 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3953 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3954 depth: MAX_SUBAGENT_DEPTH,
3955 summary_prompt: "Summarize".to_string(),
3956 context_low_prompt: "Context low".to_string(),
3957 };
3958
3959 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3960 let environment = Rc::new(FakeThreadEnvironment { handle });
3961
3962 let deep_subagent = cx.new(|cx| {
3963 let mut thread = Thread::new_subagent(
3964 project.clone(),
3965 project_context,
3966 context_server_registry,
3967 Templates::new(),
3968 model.clone(),
3969 subagent_context,
3970 std::collections::BTreeMap::new(),
3971 cx,
3972 );
3973 thread.add_default_tools(environment, cx);
3974 thread
3975 });
3976
3977 deep_subagent.read_with(cx, |thread, _| {
3978 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3979 assert!(
3980 !thread.has_registered_tool("subagent"),
3981 "subagent tool should not be present at max depth"
3982 );
3983 });
3984}
3985
3986#[gpui::test]
3987async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3988 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3989 let fake_model = model.as_fake();
3990
3991 cx.update(|cx| {
3992 cx.update_flags(true, vec!["subagents".to_string()]);
3993 });
3994
3995 let subagent_context = SubagentContext {
3996 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3997 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3998 depth: 1,
3999 summary_prompt: "Summarize your work".to_string(),
4000 context_low_prompt: "Context low, wrap up".to_string(),
4001 };
4002
4003 let project = thread.read_with(cx, |t, _| t.project.clone());
4004 let project_context = cx.new(|_cx| ProjectContext::default());
4005 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4006 let context_server_registry =
4007 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4008
4009 let subagent = cx.new(|cx| {
4010 Thread::new_subagent(
4011 project.clone(),
4012 project_context,
4013 context_server_registry,
4014 Templates::new(),
4015 model.clone(),
4016 subagent_context,
4017 std::collections::BTreeMap::new(),
4018 cx,
4019 )
4020 });
4021
4022 let task_prompt = "Find all TODO comments in the codebase";
4023 subagent
4024 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4025 .unwrap();
4026 cx.run_until_parked();
4027
4028 let pending = fake_model.pending_completions();
4029 assert_eq!(pending.len(), 1, "should have one pending completion");
4030
4031 let messages = &pending[0].messages;
4032 let user_messages: Vec<_> = messages
4033 .iter()
4034 .filter(|m| m.role == language_model::Role::User)
4035 .collect();
4036 assert_eq!(user_messages.len(), 1, "should have one user message");
4037
4038 let content = &user_messages[0].content[0];
4039 assert!(
4040 content.to_str().unwrap().contains("TODO"),
4041 "task prompt should be in user message"
4042 );
4043}
4044
4045#[gpui::test]
4046async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4047 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4048 let fake_model = model.as_fake();
4049
4050 cx.update(|cx| {
4051 cx.update_flags(true, vec!["subagents".to_string()]);
4052 });
4053
4054 let subagent_context = SubagentContext {
4055 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4056 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4057 depth: 1,
4058 summary_prompt: "Please summarize what you found".to_string(),
4059 context_low_prompt: "Context low, wrap up".to_string(),
4060 };
4061
4062 let project = thread.read_with(cx, |t, _| t.project.clone());
4063 let project_context = cx.new(|_cx| ProjectContext::default());
4064 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4065 let context_server_registry =
4066 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4067
4068 let subagent = cx.new(|cx| {
4069 Thread::new_subagent(
4070 project.clone(),
4071 project_context,
4072 context_server_registry,
4073 Templates::new(),
4074 model.clone(),
4075 subagent_context,
4076 std::collections::BTreeMap::new(),
4077 cx,
4078 )
4079 });
4080
4081 subagent
4082 .update(cx, |thread, cx| {
4083 thread.submit_user_message("Do some work", cx)
4084 })
4085 .unwrap();
4086 cx.run_until_parked();
4087
4088 fake_model.send_last_completion_stream_text_chunk("I did the work");
4089 fake_model
4090 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4091 fake_model.end_last_completion_stream();
4092 cx.run_until_parked();
4093
4094 subagent
4095 .update(cx, |thread, cx| thread.request_final_summary(cx))
4096 .unwrap();
4097 cx.run_until_parked();
4098
4099 let pending = fake_model.pending_completions();
4100 assert!(
4101 !pending.is_empty(),
4102 "should have pending completion for summary"
4103 );
4104
4105 let messages = &pending.last().unwrap().messages;
4106 let user_messages: Vec<_> = messages
4107 .iter()
4108 .filter(|m| m.role == language_model::Role::User)
4109 .collect();
4110
4111 let last_user = user_messages.last().unwrap();
4112 assert!(
4113 last_user.content[0].to_str().unwrap().contains("summarize"),
4114 "summary prompt should be sent"
4115 );
4116}
4117
4118#[gpui::test]
4119async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4120 init_test(cx);
4121
4122 cx.update(|cx| {
4123 cx.update_flags(true, vec!["subagents".to_string()]);
4124 });
4125
4126 let fs = FakeFs::new(cx.executor());
4127 fs.insert_tree(path!("/test"), json!({})).await;
4128 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4129 let project_context = cx.new(|_cx| ProjectContext::default());
4130 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4131 let context_server_registry =
4132 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4133 let model = Arc::new(FakeLanguageModel::default());
4134
4135 let subagent_context = SubagentContext {
4136 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4137 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4138 depth: 1,
4139 summary_prompt: "Summarize".to_string(),
4140 context_low_prompt: "Context low".to_string(),
4141 };
4142
4143 let subagent = cx.new(|cx| {
4144 let mut thread = Thread::new_subagent(
4145 project.clone(),
4146 project_context,
4147 context_server_registry,
4148 Templates::new(),
4149 model.clone(),
4150 subagent_context,
4151 std::collections::BTreeMap::new(),
4152 cx,
4153 );
4154 thread.add_tool(EchoTool);
4155 thread.add_tool(DelayTool);
4156 thread.add_tool(WordListTool);
4157 thread
4158 });
4159
4160 subagent.read_with(cx, |thread, _| {
4161 assert!(thread.has_registered_tool("echo"));
4162 assert!(thread.has_registered_tool("delay"));
4163 assert!(thread.has_registered_tool("word_list"));
4164 });
4165
4166 let allowed: collections::HashSet<gpui::SharedString> =
4167 vec!["echo".into()].into_iter().collect();
4168
4169 subagent.update(cx, |thread, _cx| {
4170 thread.restrict_tools(&allowed);
4171 });
4172
4173 subagent.read_with(cx, |thread, _| {
4174 assert!(
4175 thread.has_registered_tool("echo"),
4176 "echo should still be available"
4177 );
4178 assert!(
4179 !thread.has_registered_tool("delay"),
4180 "delay should be removed"
4181 );
4182 assert!(
4183 !thread.has_registered_tool("word_list"),
4184 "word_list should be removed"
4185 );
4186 });
4187}
4188
4189#[gpui::test]
4190async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4191 init_test(cx);
4192
4193 cx.update(|cx| {
4194 cx.update_flags(true, vec!["subagents".to_string()]);
4195 });
4196
4197 let fs = FakeFs::new(cx.executor());
4198 fs.insert_tree(path!("/test"), json!({})).await;
4199 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4200 let project_context = cx.new(|_cx| ProjectContext::default());
4201 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4202 let context_server_registry =
4203 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4204 let model = Arc::new(FakeLanguageModel::default());
4205
4206 let parent = cx.new(|cx| {
4207 Thread::new(
4208 project.clone(),
4209 project_context.clone(),
4210 context_server_registry.clone(),
4211 Templates::new(),
4212 Some(model.clone()),
4213 cx,
4214 )
4215 });
4216
4217 let subagent_context = SubagentContext {
4218 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4219 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4220 depth: 1,
4221 summary_prompt: "Summarize".to_string(),
4222 context_low_prompt: "Context low".to_string(),
4223 };
4224
4225 let subagent = cx.new(|cx| {
4226 Thread::new_subagent(
4227 project.clone(),
4228 project_context.clone(),
4229 context_server_registry.clone(),
4230 Templates::new(),
4231 model.clone(),
4232 subagent_context,
4233 std::collections::BTreeMap::new(),
4234 cx,
4235 )
4236 });
4237
4238 parent.update(cx, |thread, _cx| {
4239 thread.register_running_subagent(subagent.downgrade());
4240 });
4241
4242 subagent
4243 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4244 .unwrap();
4245 cx.run_until_parked();
4246
4247 subagent.read_with(cx, |thread, _| {
4248 assert!(!thread.is_turn_complete(), "subagent should be running");
4249 });
4250
4251 parent.update(cx, |thread, cx| {
4252 thread.cancel(cx).detach();
4253 });
4254
4255 subagent.read_with(cx, |thread, _| {
4256 assert!(
4257 thread.is_turn_complete(),
4258 "subagent should be cancelled when parent cancels"
4259 );
4260 });
4261}
4262
4263#[gpui::test]
4264async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4265 // This test verifies that the subagent tool properly handles user cancellation
4266 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4267 init_test(cx);
4268 always_allow_tools(cx);
4269
4270 cx.update(|cx| {
4271 cx.update_flags(true, vec!["subagents".to_string()]);
4272 });
4273
4274 let fs = FakeFs::new(cx.executor());
4275 fs.insert_tree(path!("/test"), json!({})).await;
4276 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4277 let project_context = cx.new(|_cx| ProjectContext::default());
4278 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4279 let context_server_registry =
4280 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4281 let model = Arc::new(FakeLanguageModel::default());
4282
4283 let parent = cx.new(|cx| {
4284 Thread::new(
4285 project.clone(),
4286 project_context.clone(),
4287 context_server_registry.clone(),
4288 Templates::new(),
4289 Some(model.clone()),
4290 cx,
4291 )
4292 });
4293
4294 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4295 std::collections::BTreeMap::new();
4296
4297 #[allow(clippy::arc_with_non_send_sync)]
4298 let tool = Arc::new(SubagentTool::new(
4299 parent.downgrade(),
4300 project.clone(),
4301 project_context,
4302 context_server_registry,
4303 Templates::new(),
4304 0,
4305 parent_tools,
4306 ));
4307
4308 let (event_stream, _rx, mut cancellation_tx) =
4309 crate::ToolCallEventStream::test_with_cancellation();
4310
4311 // Start the subagent tool
4312 let task = cx.update(|cx| {
4313 tool.run(
4314 SubagentToolInput {
4315 subagents: vec![crate::SubagentConfig {
4316 label: "Long running task".to_string(),
4317 task_prompt: "Do a very long task that takes forever".to_string(),
4318 summary_prompt: "Summarize".to_string(),
4319 context_low_prompt: "Context low".to_string(),
4320 timeout_ms: None,
4321 allowed_tools: None,
4322 }],
4323 },
4324 event_stream.clone(),
4325 cx,
4326 )
4327 });
4328
4329 cx.run_until_parked();
4330
4331 // Signal cancellation via the event stream
4332 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4333
4334 // The task should complete promptly with a cancellation error
4335 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4336 let result = futures::select! {
4337 result = task.fuse() => result,
4338 _ = timeout.fuse() => {
4339 panic!("subagent tool did not respond to cancellation within timeout");
4340 }
4341 };
4342
4343 // Verify we got a cancellation error
4344 let err = result.unwrap_err();
4345 assert!(
4346 err.to_string().contains("cancelled by user"),
4347 "expected cancellation error, got: {}",
4348 err
4349 );
4350}
4351
4352#[gpui::test]
4353async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4354 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4355 let fake_model = model.as_fake();
4356
4357 cx.update(|cx| {
4358 cx.update_flags(true, vec!["subagents".to_string()]);
4359 });
4360
4361 let subagent_context = SubagentContext {
4362 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4363 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4364 depth: 1,
4365 summary_prompt: "Summarize".to_string(),
4366 context_low_prompt: "Context low".to_string(),
4367 };
4368
4369 let project = thread.read_with(cx, |t, _| t.project.clone());
4370 let project_context = cx.new(|_cx| ProjectContext::default());
4371 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4372 let context_server_registry =
4373 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4374
4375 let subagent = cx.new(|cx| {
4376 Thread::new_subagent(
4377 project.clone(),
4378 project_context,
4379 context_server_registry,
4380 Templates::new(),
4381 model.clone(),
4382 subagent_context,
4383 std::collections::BTreeMap::new(),
4384 cx,
4385 )
4386 });
4387
4388 subagent
4389 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4390 .unwrap();
4391 cx.run_until_parked();
4392
4393 subagent.read_with(cx, |thread, _| {
4394 assert!(!thread.is_turn_complete(), "turn should be in progress");
4395 });
4396
4397 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4398 provider: LanguageModelProviderName::from("Fake".to_string()),
4399 });
4400 fake_model.end_last_completion_stream();
4401 cx.run_until_parked();
4402
4403 subagent.read_with(cx, |thread, _| {
4404 assert!(
4405 thread.is_turn_complete(),
4406 "turn should be complete after non-retryable error"
4407 );
4408 });
4409}
4410
4411#[gpui::test]
4412async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4414 let fake_model = model.as_fake();
4415
4416 cx.update(|cx| {
4417 cx.update_flags(true, vec!["subagents".to_string()]);
4418 });
4419
4420 let subagent_context = SubagentContext {
4421 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4422 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4423 depth: 1,
4424 summary_prompt: "Summarize your work".to_string(),
4425 context_low_prompt: "Context low, stop and summarize".to_string(),
4426 };
4427
4428 let project = thread.read_with(cx, |t, _| t.project.clone());
4429 let project_context = cx.new(|_cx| ProjectContext::default());
4430 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4431 let context_server_registry =
4432 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4433
4434 let subagent = cx.new(|cx| {
4435 Thread::new_subagent(
4436 project.clone(),
4437 project_context.clone(),
4438 context_server_registry.clone(),
4439 Templates::new(),
4440 model.clone(),
4441 subagent_context.clone(),
4442 std::collections::BTreeMap::new(),
4443 cx,
4444 )
4445 });
4446
4447 subagent.update(cx, |thread, _| {
4448 thread.add_tool(EchoTool);
4449 });
4450
4451 subagent
4452 .update(cx, |thread, cx| {
4453 thread.submit_user_message("Do some work", cx)
4454 })
4455 .unwrap();
4456 cx.run_until_parked();
4457
4458 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4459 fake_model
4460 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4461 fake_model.end_last_completion_stream();
4462 cx.run_until_parked();
4463
4464 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4465 assert!(
4466 interrupt_result.is_ok(),
4467 "interrupt_for_summary should succeed"
4468 );
4469
4470 cx.run_until_parked();
4471
4472 let pending = fake_model.pending_completions();
4473 assert!(
4474 !pending.is_empty(),
4475 "should have pending completion for interrupted summary"
4476 );
4477
4478 let messages = &pending.last().unwrap().messages;
4479 let user_messages: Vec<_> = messages
4480 .iter()
4481 .filter(|m| m.role == language_model::Role::User)
4482 .collect();
4483
4484 let last_user = user_messages.last().unwrap();
4485 let content_str = last_user.content[0].to_str().unwrap();
4486 assert!(
4487 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4488 "context_low_prompt should be sent when interrupting: got {:?}",
4489 content_str
4490 );
4491}
4492
4493#[gpui::test]
4494async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4495 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4496 let fake_model = model.as_fake();
4497
4498 cx.update(|cx| {
4499 cx.update_flags(true, vec!["subagents".to_string()]);
4500 });
4501
4502 let subagent_context = SubagentContext {
4503 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4504 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4505 depth: 1,
4506 summary_prompt: "Summarize".to_string(),
4507 context_low_prompt: "Context low".to_string(),
4508 };
4509
4510 let project = thread.read_with(cx, |t, _| t.project.clone());
4511 let project_context = cx.new(|_cx| ProjectContext::default());
4512 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4513 let context_server_registry =
4514 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4515
4516 let subagent = cx.new(|cx| {
4517 Thread::new_subagent(
4518 project.clone(),
4519 project_context,
4520 context_server_registry,
4521 Templates::new(),
4522 model.clone(),
4523 subagent_context,
4524 std::collections::BTreeMap::new(),
4525 cx,
4526 )
4527 });
4528
4529 subagent
4530 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4531 .unwrap();
4532 cx.run_until_parked();
4533
4534 let max_tokens = model.max_token_count();
4535 let high_usage = language_model::TokenUsage {
4536 input_tokens: (max_tokens as f64 * 0.80) as u64,
4537 output_tokens: 0,
4538 cache_creation_input_tokens: 0,
4539 cache_read_input_tokens: 0,
4540 };
4541
4542 fake_model
4543 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4544 fake_model.send_last_completion_stream_text_chunk("Working...");
4545 fake_model
4546 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4547 fake_model.end_last_completion_stream();
4548 cx.run_until_parked();
4549
4550 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4551 assert!(usage.is_some(), "should have token usage after completion");
4552
4553 let usage = usage.unwrap();
4554 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4555 assert!(
4556 remaining_ratio <= 0.25,
4557 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4558 remaining_ratio * 100.0
4559 );
4560}
4561
4562#[gpui::test]
4563async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4564 init_test(cx);
4565
4566 cx.update(|cx| {
4567 cx.update_flags(true, vec!["subagents".to_string()]);
4568 });
4569
4570 let fs = FakeFs::new(cx.executor());
4571 fs.insert_tree(path!("/test"), json!({})).await;
4572 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4573 let project_context = cx.new(|_cx| ProjectContext::default());
4574 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4575 let context_server_registry =
4576 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4577 let model = Arc::new(FakeLanguageModel::default());
4578
4579 let parent = cx.new(|cx| {
4580 let mut thread = Thread::new(
4581 project.clone(),
4582 project_context.clone(),
4583 context_server_registry.clone(),
4584 Templates::new(),
4585 Some(model.clone()),
4586 cx,
4587 );
4588 thread.add_tool(EchoTool);
4589 thread
4590 });
4591
4592 let mut parent_tools: std::collections::BTreeMap<
4593 gpui::SharedString,
4594 Arc<dyn crate::AnyAgentTool>,
4595 > = std::collections::BTreeMap::new();
4596 parent_tools.insert("echo".into(), EchoTool.erase());
4597
4598 #[allow(clippy::arc_with_non_send_sync)]
4599 let tool = Arc::new(SubagentTool::new(
4600 parent.downgrade(),
4601 project,
4602 project_context,
4603 context_server_registry,
4604 Templates::new(),
4605 0,
4606 parent_tools,
4607 ));
4608
4609 let subagent_configs = vec![crate::SubagentConfig {
4610 label: "Test".to_string(),
4611 task_prompt: "Do something".to_string(),
4612 summary_prompt: "Summarize".to_string(),
4613 context_low_prompt: "Context low".to_string(),
4614 timeout_ms: None,
4615 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4616 }];
4617 let result = tool.validate_subagents(&subagent_configs);
4618 assert!(result.is_err(), "should reject unknown tool");
4619 let err_msg = result.unwrap_err().to_string();
4620 assert!(
4621 err_msg.contains("nonexistent_tool"),
4622 "error should mention the invalid tool name: {}",
4623 err_msg
4624 );
4625 assert!(
4626 err_msg.contains("do not exist"),
4627 "error should explain the tool does not exist: {}",
4628 err_msg
4629 );
4630}
4631
4632#[gpui::test]
4633async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4634 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4635 let fake_model = model.as_fake();
4636
4637 cx.update(|cx| {
4638 cx.update_flags(true, vec!["subagents".to_string()]);
4639 });
4640
4641 let subagent_context = SubagentContext {
4642 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4643 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4644 depth: 1,
4645 summary_prompt: "Summarize".to_string(),
4646 context_low_prompt: "Context low".to_string(),
4647 };
4648
4649 let project = thread.read_with(cx, |t, _| t.project.clone());
4650 let project_context = cx.new(|_cx| ProjectContext::default());
4651 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4652 let context_server_registry =
4653 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4654
4655 let subagent = cx.new(|cx| {
4656 Thread::new_subagent(
4657 project.clone(),
4658 project_context,
4659 context_server_registry,
4660 Templates::new(),
4661 model.clone(),
4662 subagent_context,
4663 std::collections::BTreeMap::new(),
4664 cx,
4665 )
4666 });
4667
4668 subagent
4669 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4670 .unwrap();
4671 cx.run_until_parked();
4672
4673 fake_model
4674 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4675 fake_model.end_last_completion_stream();
4676 cx.run_until_parked();
4677
4678 subagent.read_with(cx, |thread, _| {
4679 assert!(
4680 thread.is_turn_complete(),
4681 "turn should complete even with empty response"
4682 );
4683 });
4684}
4685
4686#[gpui::test]
4687async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4688 init_test(cx);
4689
4690 cx.update(|cx| {
4691 cx.update_flags(true, vec!["subagents".to_string()]);
4692 });
4693
4694 let fs = FakeFs::new(cx.executor());
4695 fs.insert_tree(path!("/test"), json!({})).await;
4696 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4697 let project_context = cx.new(|_cx| ProjectContext::default());
4698 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4699 let context_server_registry =
4700 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4701 let model = Arc::new(FakeLanguageModel::default());
4702
4703 let depth_1_context = SubagentContext {
4704 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4705 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4706 depth: 1,
4707 summary_prompt: "Summarize".to_string(),
4708 context_low_prompt: "Context low".to_string(),
4709 };
4710
4711 let depth_1_subagent = cx.new(|cx| {
4712 Thread::new_subagent(
4713 project.clone(),
4714 project_context.clone(),
4715 context_server_registry.clone(),
4716 Templates::new(),
4717 model.clone(),
4718 depth_1_context,
4719 std::collections::BTreeMap::new(),
4720 cx,
4721 )
4722 });
4723
4724 depth_1_subagent.read_with(cx, |thread, _| {
4725 assert_eq!(thread.depth(), 1);
4726 assert!(thread.is_subagent());
4727 });
4728
4729 let depth_2_context = SubagentContext {
4730 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4731 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4732 depth: 2,
4733 summary_prompt: "Summarize depth 2".to_string(),
4734 context_low_prompt: "Context low depth 2".to_string(),
4735 };
4736
4737 let depth_2_subagent = cx.new(|cx| {
4738 Thread::new_subagent(
4739 project.clone(),
4740 project_context.clone(),
4741 context_server_registry.clone(),
4742 Templates::new(),
4743 model.clone(),
4744 depth_2_context,
4745 std::collections::BTreeMap::new(),
4746 cx,
4747 )
4748 });
4749
4750 depth_2_subagent.read_with(cx, |thread, _| {
4751 assert_eq!(thread.depth(), 2);
4752 assert!(thread.is_subagent());
4753 });
4754
4755 depth_2_subagent
4756 .update(cx, |thread, cx| {
4757 thread.submit_user_message("Nested task", cx)
4758 })
4759 .unwrap();
4760 cx.run_until_parked();
4761
4762 let pending = model.as_fake().pending_completions();
4763 assert!(
4764 !pending.is_empty(),
4765 "depth-2 subagent should be able to submit messages"
4766 );
4767}
4768
4769#[gpui::test]
4770async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4771 init_test(cx);
4772 always_allow_tools(cx);
4773
4774 cx.update(|cx| {
4775 cx.update_flags(true, vec!["subagents".to_string()]);
4776 });
4777
4778 let fs = FakeFs::new(cx.executor());
4779 fs.insert_tree(path!("/test"), json!({})).await;
4780 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4781 let project_context = cx.new(|_cx| ProjectContext::default());
4782 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4783 let context_server_registry =
4784 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4785 let model = Arc::new(FakeLanguageModel::default());
4786 let fake_model = model.as_fake();
4787
4788 let subagent_context = SubagentContext {
4789 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4790 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4791 depth: 1,
4792 summary_prompt: "Summarize what you did".to_string(),
4793 context_low_prompt: "Context low".to_string(),
4794 };
4795
4796 let subagent = cx.new(|cx| {
4797 let mut thread = Thread::new_subagent(
4798 project.clone(),
4799 project_context,
4800 context_server_registry,
4801 Templates::new(),
4802 model.clone(),
4803 subagent_context,
4804 std::collections::BTreeMap::new(),
4805 cx,
4806 );
4807 thread.add_tool(EchoTool);
4808 thread
4809 });
4810
4811 subagent.read_with(cx, |thread, _| {
4812 assert!(
4813 thread.has_registered_tool("echo"),
4814 "subagent should have echo tool"
4815 );
4816 });
4817
4818 subagent
4819 .update(cx, |thread, cx| {
4820 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4821 })
4822 .unwrap();
4823 cx.run_until_parked();
4824
4825 let tool_use = LanguageModelToolUse {
4826 id: "tool_call_1".into(),
4827 name: EchoTool::name().into(),
4828 raw_input: json!({"text": "hello world"}).to_string(),
4829 input: json!({"text": "hello world"}),
4830 is_input_complete: true,
4831 thought_signature: None,
4832 };
4833 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4834 fake_model.end_last_completion_stream();
4835 cx.run_until_parked();
4836
4837 let pending = fake_model.pending_completions();
4838 assert!(
4839 !pending.is_empty(),
4840 "should have pending completion after tool use"
4841 );
4842
4843 let last_completion = pending.last().unwrap();
4844 let has_tool_result = last_completion.messages.iter().any(|m| {
4845 m.content
4846 .iter()
4847 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4848 });
4849 assert!(
4850 has_tool_result,
4851 "tool result should be in the messages sent back to the model"
4852 );
4853}
4854
4855#[gpui::test]
4856async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4857 init_test(cx);
4858
4859 cx.update(|cx| {
4860 cx.update_flags(true, vec!["subagents".to_string()]);
4861 });
4862
4863 let fs = FakeFs::new(cx.executor());
4864 fs.insert_tree(path!("/test"), json!({})).await;
4865 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4866 let project_context = cx.new(|_cx| ProjectContext::default());
4867 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4868 let context_server_registry =
4869 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4870 let model = Arc::new(FakeLanguageModel::default());
4871
4872 let parent = cx.new(|cx| {
4873 Thread::new(
4874 project.clone(),
4875 project_context.clone(),
4876 context_server_registry.clone(),
4877 Templates::new(),
4878 Some(model.clone()),
4879 cx,
4880 )
4881 });
4882
4883 let mut subagents = Vec::new();
4884 for i in 0..MAX_PARALLEL_SUBAGENTS {
4885 let subagent_context = SubagentContext {
4886 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4887 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4888 depth: 1,
4889 summary_prompt: "Summarize".to_string(),
4890 context_low_prompt: "Context low".to_string(),
4891 };
4892
4893 let subagent = cx.new(|cx| {
4894 Thread::new_subagent(
4895 project.clone(),
4896 project_context.clone(),
4897 context_server_registry.clone(),
4898 Templates::new(),
4899 model.clone(),
4900 subagent_context,
4901 std::collections::BTreeMap::new(),
4902 cx,
4903 )
4904 });
4905
4906 parent.update(cx, |thread, _cx| {
4907 thread.register_running_subagent(subagent.downgrade());
4908 });
4909 subagents.push(subagent);
4910 }
4911
4912 parent.read_with(cx, |thread, _| {
4913 assert_eq!(
4914 thread.running_subagent_count(),
4915 MAX_PARALLEL_SUBAGENTS,
4916 "should have MAX_PARALLEL_SUBAGENTS registered"
4917 );
4918 });
4919
4920 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4921 std::collections::BTreeMap::new();
4922
4923 #[allow(clippy::arc_with_non_send_sync)]
4924 let tool = Arc::new(SubagentTool::new(
4925 parent.downgrade(),
4926 project.clone(),
4927 project_context,
4928 context_server_registry,
4929 Templates::new(),
4930 0,
4931 parent_tools,
4932 ));
4933
4934 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4935
4936 let result = cx.update(|cx| {
4937 tool.run(
4938 SubagentToolInput {
4939 subagents: vec![crate::SubagentConfig {
4940 label: "Test".to_string(),
4941 task_prompt: "Do something".to_string(),
4942 summary_prompt: "Summarize".to_string(),
4943 context_low_prompt: "Context low".to_string(),
4944 timeout_ms: None,
4945 allowed_tools: None,
4946 }],
4947 },
4948 event_stream,
4949 cx,
4950 )
4951 });
4952
4953 let err = result.await.unwrap_err();
4954 assert!(
4955 err.to_string().contains("Maximum parallel subagents"),
4956 "should reject when max parallel subagents reached: {}",
4957 err
4958 );
4959
4960 drop(subagents);
4961}
4962
4963#[gpui::test]
4964async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4965 init_test(cx);
4966 always_allow_tools(cx);
4967
4968 cx.update(|cx| {
4969 cx.update_flags(true, vec!["subagents".to_string()]);
4970 });
4971
4972 let fs = FakeFs::new(cx.executor());
4973 fs.insert_tree(path!("/test"), json!({})).await;
4974 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4975 let project_context = cx.new(|_cx| ProjectContext::default());
4976 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4977 let context_server_registry =
4978 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4979 let model = Arc::new(FakeLanguageModel::default());
4980 let fake_model = model.as_fake();
4981
4982 let parent = cx.new(|cx| {
4983 let mut thread = Thread::new(
4984 project.clone(),
4985 project_context.clone(),
4986 context_server_registry.clone(),
4987 Templates::new(),
4988 Some(model.clone()),
4989 cx,
4990 );
4991 thread.add_tool(EchoTool);
4992 thread
4993 });
4994
4995 let mut parent_tools: std::collections::BTreeMap<
4996 gpui::SharedString,
4997 Arc<dyn crate::AnyAgentTool>,
4998 > = std::collections::BTreeMap::new();
4999 parent_tools.insert("echo".into(), EchoTool.erase());
5000
5001 #[allow(clippy::arc_with_non_send_sync)]
5002 let tool = Arc::new(SubagentTool::new(
5003 parent.downgrade(),
5004 project.clone(),
5005 project_context,
5006 context_server_registry,
5007 Templates::new(),
5008 0,
5009 parent_tools,
5010 ));
5011
5012 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5013
5014 let task = cx.update(|cx| {
5015 tool.run(
5016 SubagentToolInput {
5017 subagents: vec![crate::SubagentConfig {
5018 label: "Research task".to_string(),
5019 task_prompt: "Find all TODOs in the codebase".to_string(),
5020 summary_prompt: "Summarize what you found".to_string(),
5021 context_low_prompt: "Context low, wrap up".to_string(),
5022 timeout_ms: None,
5023 allowed_tools: None,
5024 }],
5025 },
5026 event_stream,
5027 cx,
5028 )
5029 });
5030
5031 cx.run_until_parked();
5032
5033 let pending = fake_model.pending_completions();
5034 assert!(
5035 !pending.is_empty(),
5036 "subagent should have started and sent a completion request"
5037 );
5038
5039 let first_completion = &pending[0];
5040 let has_task_prompt = first_completion.messages.iter().any(|m| {
5041 m.role == language_model::Role::User
5042 && m.content
5043 .iter()
5044 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5045 });
5046 assert!(has_task_prompt, "task prompt should be sent to subagent");
5047
5048 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5049 fake_model
5050 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5051 fake_model.end_last_completion_stream();
5052 cx.run_until_parked();
5053
5054 let pending = fake_model.pending_completions();
5055 assert!(
5056 !pending.is_empty(),
5057 "should have pending completion for summary request"
5058 );
5059
5060 let last_completion = pending.last().unwrap();
5061 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5062 m.role == language_model::Role::User
5063 && m.content.iter().any(|c| {
5064 c.to_str()
5065 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5066 .unwrap_or(false)
5067 })
5068 });
5069 assert!(
5070 has_summary_prompt,
5071 "summary prompt should be sent after task completion"
5072 );
5073
5074 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5075 fake_model
5076 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5077 fake_model.end_last_completion_stream();
5078 cx.run_until_parked();
5079
5080 let result = task.await;
5081 assert!(result.is_ok(), "subagent tool should complete successfully");
5082
5083 let summary = result.unwrap();
5084 assert!(
5085 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5086 "summary should contain subagent's response: {}",
5087 summary
5088 );
5089}
5090
5091#[gpui::test]
5092async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5093 init_test(cx);
5094
5095 let fs = FakeFs::new(cx.executor());
5096 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5097 .await;
5098 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5099
5100 cx.update(|cx| {
5101 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5102 settings.tool_permissions.tools.insert(
5103 "edit_file".into(),
5104 agent_settings::ToolRules {
5105 default_mode: settings::ToolPermissionMode::Allow,
5106 always_allow: vec![],
5107 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5108 always_confirm: vec![],
5109 invalid_patterns: vec![],
5110 },
5111 );
5112 agent_settings::AgentSettings::override_global(settings, cx);
5113 });
5114
5115 let context_server_registry =
5116 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5117 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5118 let templates = crate::Templates::new();
5119 let thread = cx.new(|cx| {
5120 crate::Thread::new(
5121 project.clone(),
5122 cx.new(|_cx| prompt_store::ProjectContext::default()),
5123 context_server_registry,
5124 templates.clone(),
5125 None,
5126 cx,
5127 )
5128 });
5129
5130 #[allow(clippy::arc_with_non_send_sync)]
5131 let tool = Arc::new(crate::EditFileTool::new(
5132 project.clone(),
5133 thread.downgrade(),
5134 language_registry,
5135 templates,
5136 ));
5137 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5138
5139 let task = cx.update(|cx| {
5140 tool.run(
5141 crate::EditFileToolInput {
5142 display_description: "Edit sensitive file".to_string(),
5143 path: "root/sensitive_config.txt".into(),
5144 mode: crate::EditFileMode::Edit,
5145 },
5146 event_stream,
5147 cx,
5148 )
5149 });
5150
5151 let result = task.await;
5152 assert!(result.is_err(), "expected edit to be blocked");
5153 assert!(
5154 result.unwrap_err().to_string().contains("blocked"),
5155 "error should mention the edit was blocked"
5156 );
5157}
5158
5159#[gpui::test]
5160async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5161 init_test(cx);
5162
5163 let fs = FakeFs::new(cx.executor());
5164 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5165 .await;
5166 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5167
5168 cx.update(|cx| {
5169 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5170 settings.tool_permissions.tools.insert(
5171 "delete_path".into(),
5172 agent_settings::ToolRules {
5173 default_mode: settings::ToolPermissionMode::Allow,
5174 always_allow: vec![],
5175 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5176 always_confirm: vec![],
5177 invalid_patterns: vec![],
5178 },
5179 );
5180 agent_settings::AgentSettings::override_global(settings, cx);
5181 });
5182
5183 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5184
5185 #[allow(clippy::arc_with_non_send_sync)]
5186 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5187 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5188
5189 let task = cx.update(|cx| {
5190 tool.run(
5191 crate::DeletePathToolInput {
5192 path: "root/important_data.txt".to_string(),
5193 },
5194 event_stream,
5195 cx,
5196 )
5197 });
5198
5199 let result = task.await;
5200 assert!(result.is_err(), "expected deletion to be blocked");
5201 assert!(
5202 result.unwrap_err().to_string().contains("blocked"),
5203 "error should mention the deletion was blocked"
5204 );
5205}
5206
5207#[gpui::test]
5208async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5209 init_test(cx);
5210
5211 let fs = FakeFs::new(cx.executor());
5212 fs.insert_tree(
5213 "/root",
5214 json!({
5215 "safe.txt": "content",
5216 "protected": {}
5217 }),
5218 )
5219 .await;
5220 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5221
5222 cx.update(|cx| {
5223 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5224 settings.tool_permissions.tools.insert(
5225 "move_path".into(),
5226 agent_settings::ToolRules {
5227 default_mode: settings::ToolPermissionMode::Allow,
5228 always_allow: vec![],
5229 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5230 always_confirm: vec![],
5231 invalid_patterns: vec![],
5232 },
5233 );
5234 agent_settings::AgentSettings::override_global(settings, cx);
5235 });
5236
5237 #[allow(clippy::arc_with_non_send_sync)]
5238 let tool = Arc::new(crate::MovePathTool::new(project));
5239 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5240
5241 let task = cx.update(|cx| {
5242 tool.run(
5243 crate::MovePathToolInput {
5244 source_path: "root/safe.txt".to_string(),
5245 destination_path: "root/protected/safe.txt".to_string(),
5246 },
5247 event_stream,
5248 cx,
5249 )
5250 });
5251
5252 let result = task.await;
5253 assert!(
5254 result.is_err(),
5255 "expected move to be blocked due to destination path"
5256 );
5257 assert!(
5258 result.unwrap_err().to_string().contains("blocked"),
5259 "error should mention the move was blocked"
5260 );
5261}
5262
5263#[gpui::test]
5264async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5265 init_test(cx);
5266
5267 let fs = FakeFs::new(cx.executor());
5268 fs.insert_tree(
5269 "/root",
5270 json!({
5271 "secret.txt": "secret content",
5272 "public": {}
5273 }),
5274 )
5275 .await;
5276 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5277
5278 cx.update(|cx| {
5279 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5280 settings.tool_permissions.tools.insert(
5281 "move_path".into(),
5282 agent_settings::ToolRules {
5283 default_mode: settings::ToolPermissionMode::Allow,
5284 always_allow: vec![],
5285 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5286 always_confirm: vec![],
5287 invalid_patterns: vec![],
5288 },
5289 );
5290 agent_settings::AgentSettings::override_global(settings, cx);
5291 });
5292
5293 #[allow(clippy::arc_with_non_send_sync)]
5294 let tool = Arc::new(crate::MovePathTool::new(project));
5295 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5296
5297 let task = cx.update(|cx| {
5298 tool.run(
5299 crate::MovePathToolInput {
5300 source_path: "root/secret.txt".to_string(),
5301 destination_path: "root/public/not_secret.txt".to_string(),
5302 },
5303 event_stream,
5304 cx,
5305 )
5306 });
5307
5308 let result = task.await;
5309 assert!(
5310 result.is_err(),
5311 "expected move to be blocked due to source path"
5312 );
5313 assert!(
5314 result.unwrap_err().to_string().contains("blocked"),
5315 "error should mention the move was blocked"
5316 );
5317}
5318
5319#[gpui::test]
5320async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5321 init_test(cx);
5322
5323 let fs = FakeFs::new(cx.executor());
5324 fs.insert_tree(
5325 "/root",
5326 json!({
5327 "confidential.txt": "confidential data",
5328 "dest": {}
5329 }),
5330 )
5331 .await;
5332 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5333
5334 cx.update(|cx| {
5335 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5336 settings.tool_permissions.tools.insert(
5337 "copy_path".into(),
5338 agent_settings::ToolRules {
5339 default_mode: settings::ToolPermissionMode::Allow,
5340 always_allow: vec![],
5341 always_deny: vec![
5342 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5343 ],
5344 always_confirm: vec![],
5345 invalid_patterns: vec![],
5346 },
5347 );
5348 agent_settings::AgentSettings::override_global(settings, cx);
5349 });
5350
5351 #[allow(clippy::arc_with_non_send_sync)]
5352 let tool = Arc::new(crate::CopyPathTool::new(project));
5353 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5354
5355 let task = cx.update(|cx| {
5356 tool.run(
5357 crate::CopyPathToolInput {
5358 source_path: "root/confidential.txt".to_string(),
5359 destination_path: "root/dest/copy.txt".to_string(),
5360 },
5361 event_stream,
5362 cx,
5363 )
5364 });
5365
5366 let result = task.await;
5367 assert!(result.is_err(), "expected copy to be blocked");
5368 assert!(
5369 result.unwrap_err().to_string().contains("blocked"),
5370 "error should mention the copy was blocked"
5371 );
5372}
5373
5374#[gpui::test]
5375async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5376 init_test(cx);
5377
5378 let fs = FakeFs::new(cx.executor());
5379 fs.insert_tree(
5380 "/root",
5381 json!({
5382 "normal.txt": "normal content",
5383 "readonly": {
5384 "config.txt": "readonly content"
5385 }
5386 }),
5387 )
5388 .await;
5389 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5390
5391 cx.update(|cx| {
5392 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5393 settings.tool_permissions.tools.insert(
5394 "save_file".into(),
5395 agent_settings::ToolRules {
5396 default_mode: settings::ToolPermissionMode::Allow,
5397 always_allow: vec![],
5398 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5399 always_confirm: vec![],
5400 invalid_patterns: vec![],
5401 },
5402 );
5403 agent_settings::AgentSettings::override_global(settings, cx);
5404 });
5405
5406 #[allow(clippy::arc_with_non_send_sync)]
5407 let tool = Arc::new(crate::SaveFileTool::new(project));
5408 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5409
5410 let task = cx.update(|cx| {
5411 tool.run(
5412 crate::SaveFileToolInput {
5413 paths: vec![
5414 std::path::PathBuf::from("root/normal.txt"),
5415 std::path::PathBuf::from("root/readonly/config.txt"),
5416 ],
5417 },
5418 event_stream,
5419 cx,
5420 )
5421 });
5422
5423 let result = task.await;
5424 assert!(
5425 result.is_err(),
5426 "expected save to be blocked due to denied path"
5427 );
5428 assert!(
5429 result.unwrap_err().to_string().contains("blocked"),
5430 "error should mention the save was blocked"
5431 );
5432}
5433
5434#[gpui::test]
5435async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5436 init_test(cx);
5437
5438 let fs = FakeFs::new(cx.executor());
5439 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5440 .await;
5441 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5442
5443 cx.update(|cx| {
5444 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5445 settings.always_allow_tool_actions = false;
5446 settings.tool_permissions.tools.insert(
5447 "save_file".into(),
5448 agent_settings::ToolRules {
5449 default_mode: settings::ToolPermissionMode::Allow,
5450 always_allow: vec![],
5451 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5452 always_confirm: vec![],
5453 invalid_patterns: vec![],
5454 },
5455 );
5456 agent_settings::AgentSettings::override_global(settings, cx);
5457 });
5458
5459 #[allow(clippy::arc_with_non_send_sync)]
5460 let tool = Arc::new(crate::SaveFileTool::new(project));
5461 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5462
5463 let task = cx.update(|cx| {
5464 tool.run(
5465 crate::SaveFileToolInput {
5466 paths: vec![std::path::PathBuf::from("root/config.secret")],
5467 },
5468 event_stream,
5469 cx,
5470 )
5471 });
5472
5473 let result = task.await;
5474 assert!(result.is_err(), "expected save to be blocked");
5475 assert!(
5476 result.unwrap_err().to_string().contains("blocked"),
5477 "error should mention the save was blocked"
5478 );
5479}
5480
5481#[gpui::test]
5482async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5483 init_test(cx);
5484
5485 cx.update(|cx| {
5486 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5487 settings.tool_permissions.tools.insert(
5488 "web_search".into(),
5489 agent_settings::ToolRules {
5490 default_mode: settings::ToolPermissionMode::Allow,
5491 always_allow: vec![],
5492 always_deny: vec![
5493 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5494 ],
5495 always_confirm: vec![],
5496 invalid_patterns: vec![],
5497 },
5498 );
5499 agent_settings::AgentSettings::override_global(settings, cx);
5500 });
5501
5502 #[allow(clippy::arc_with_non_send_sync)]
5503 let tool = Arc::new(crate::WebSearchTool);
5504 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5505
5506 let input: crate::WebSearchToolInput =
5507 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5508
5509 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5510
5511 let result = task.await;
5512 assert!(result.is_err(), "expected search to be blocked");
5513 assert!(
5514 result.unwrap_err().to_string().contains("blocked"),
5515 "error should mention the search was blocked"
5516 );
5517}
5518
5519#[gpui::test]
5520async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5521 init_test(cx);
5522
5523 let fs = FakeFs::new(cx.executor());
5524 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5525 .await;
5526 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5527
5528 cx.update(|cx| {
5529 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5530 settings.always_allow_tool_actions = false;
5531 settings.tool_permissions.tools.insert(
5532 "edit_file".into(),
5533 agent_settings::ToolRules {
5534 default_mode: settings::ToolPermissionMode::Confirm,
5535 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5536 always_deny: vec![],
5537 always_confirm: vec![],
5538 invalid_patterns: vec![],
5539 },
5540 );
5541 agent_settings::AgentSettings::override_global(settings, cx);
5542 });
5543
5544 let context_server_registry =
5545 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5546 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5547 let templates = crate::Templates::new();
5548 let thread = cx.new(|cx| {
5549 crate::Thread::new(
5550 project.clone(),
5551 cx.new(|_cx| prompt_store::ProjectContext::default()),
5552 context_server_registry,
5553 templates.clone(),
5554 None,
5555 cx,
5556 )
5557 });
5558
5559 #[allow(clippy::arc_with_non_send_sync)]
5560 let tool = Arc::new(crate::EditFileTool::new(
5561 project,
5562 thread.downgrade(),
5563 language_registry,
5564 templates,
5565 ));
5566 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5567
5568 let _task = cx.update(|cx| {
5569 tool.run(
5570 crate::EditFileToolInput {
5571 display_description: "Edit README".to_string(),
5572 path: "root/README.md".into(),
5573 mode: crate::EditFileMode::Edit,
5574 },
5575 event_stream,
5576 cx,
5577 )
5578 });
5579
5580 cx.run_until_parked();
5581
5582 let event = rx.try_next();
5583 assert!(
5584 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5585 "expected no authorization request for allowed .md file"
5586 );
5587}
5588
5589#[gpui::test]
5590async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5591 init_test(cx);
5592
5593 cx.update(|cx| {
5594 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5595 settings.tool_permissions.tools.insert(
5596 "fetch".into(),
5597 agent_settings::ToolRules {
5598 default_mode: settings::ToolPermissionMode::Allow,
5599 always_allow: vec![],
5600 always_deny: vec![
5601 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5602 ],
5603 always_confirm: vec![],
5604 invalid_patterns: vec![],
5605 },
5606 );
5607 agent_settings::AgentSettings::override_global(settings, cx);
5608 });
5609
5610 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5611
5612 #[allow(clippy::arc_with_non_send_sync)]
5613 let tool = Arc::new(crate::FetchTool::new(http_client));
5614 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5615
5616 let input: crate::FetchToolInput =
5617 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5618
5619 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5620
5621 let result = task.await;
5622 assert!(result.is_err(), "expected fetch to be blocked");
5623 assert!(
5624 result.unwrap_err().to_string().contains("blocked"),
5625 "error should mention the fetch was blocked"
5626 );
5627}
5628
5629#[gpui::test]
5630async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5631 init_test(cx);
5632
5633 cx.update(|cx| {
5634 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5635 settings.always_allow_tool_actions = false;
5636 settings.tool_permissions.tools.insert(
5637 "fetch".into(),
5638 agent_settings::ToolRules {
5639 default_mode: settings::ToolPermissionMode::Confirm,
5640 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5641 always_deny: vec![],
5642 always_confirm: vec![],
5643 invalid_patterns: vec![],
5644 },
5645 );
5646 agent_settings::AgentSettings::override_global(settings, cx);
5647 });
5648
5649 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5650
5651 #[allow(clippy::arc_with_non_send_sync)]
5652 let tool = Arc::new(crate::FetchTool::new(http_client));
5653 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5654
5655 let input: crate::FetchToolInput =
5656 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5657
5658 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5659
5660 cx.run_until_parked();
5661
5662 let event = rx.try_next();
5663 assert!(
5664 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5665 "expected no authorization request for allowed docs.rs URL"
5666 );
5667}
5668
5669#[gpui::test]
5670async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5671 init_test(cx);
5672 always_allow_tools(cx);
5673
5674 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5675 let fake_model = model.as_fake();
5676
5677 // Add a tool so we can simulate tool calls
5678 thread.update(cx, |thread, _cx| {
5679 thread.add_tool(EchoTool);
5680 });
5681
5682 // Start a turn by sending a message
5683 let mut events = thread
5684 .update(cx, |thread, cx| {
5685 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5686 })
5687 .unwrap();
5688 cx.run_until_parked();
5689
5690 // Simulate the model making a tool call
5691 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5692 LanguageModelToolUse {
5693 id: "tool_1".into(),
5694 name: "echo".into(),
5695 raw_input: r#"{"text": "hello"}"#.into(),
5696 input: json!({"text": "hello"}),
5697 is_input_complete: true,
5698 thought_signature: None,
5699 },
5700 ));
5701 fake_model
5702 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5703
5704 // Queue a message before ending the stream
5705 thread.update(cx, |thread, _cx| {
5706 thread.queue_message(
5707 vec![acp::ContentBlock::Text(acp::TextContent::new(
5708 "This is my queued message".to_string(),
5709 ))],
5710 vec![],
5711 );
5712 });
5713
5714 // Now end the stream - tool will run, and the boundary check should see the queue
5715 fake_model.end_last_completion_stream();
5716
5717 // Collect all events until the turn stops
5718 let all_events = collect_events_until_stop(&mut events, cx).await;
5719
5720 // Verify we received the tool call event
5721 let tool_call_ids: Vec<_> = all_events
5722 .iter()
5723 .filter_map(|e| match e {
5724 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5725 _ => None,
5726 })
5727 .collect();
5728 assert_eq!(
5729 tool_call_ids,
5730 vec!["tool_1"],
5731 "Should have received a tool call event for our echo tool"
5732 );
5733
5734 // The turn should have stopped with EndTurn
5735 let stop_reasons = stop_events(all_events);
5736 assert_eq!(
5737 stop_reasons,
5738 vec![acp::StopReason::EndTurn],
5739 "Turn should have ended after tool completion due to queued message"
5740 );
5741
5742 // Verify the queued message is still there
5743 thread.update(cx, |thread, _cx| {
5744 let queued = thread.queued_messages();
5745 assert_eq!(queued.len(), 1, "Should still have one queued message");
5746 assert!(matches!(
5747 &queued[0].content[0],
5748 acp::ContentBlock::Text(t) if t.text == "This is my queued message"
5749 ));
5750 });
5751
5752 // Thread should be idle now
5753 thread.update(cx, |thread, _cx| {
5754 assert!(
5755 thread.is_turn_complete(),
5756 "Thread should not be running after turn ends"
5757 );
5758 });
5759}