1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use feature_flags::FeatureFlagAppExt as _;
11use fs::{FakeFs, Fs};
12use futures::{
13 FutureExt as _, StreamExt,
14 channel::{
15 mpsc::{self, UnboundedReceiver},
16 oneshot,
17 },
18 future::{Fuse, Shared},
19};
20use gpui::{
21 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
22 http_client::FakeHttpClient,
23};
24use indoc::indoc;
25use language_model::{
26 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
27 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
29 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
30};
31use pretty_assertions::assert_eq;
32use project::{
33 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
34};
35use prompt_store::ProjectContext;
36use reqwest_client::ReqwestClient;
37use schemars::JsonSchema;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use settings::{Settings, SettingsStore};
41use std::{
42 path::Path,
43 pin::Pin,
44 rc::Rc,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, Ordering},
48 },
49 time::Duration,
50};
51use util::path;
52
53mod test_tools;
54use test_tools::*;
55
56fn init_test(cx: &mut TestAppContext) {
57 cx.update(|cx| {
58 let settings_store = SettingsStore::test(cx);
59 cx.set_global(settings_store);
60 });
61}
62
/// Test double for a terminal handle: records whether it was killed and lets
/// tests control when the terminal "exits".
struct FakeTerminalHandle {
    // Set to true when `kill` is called.
    killed: Arc<AtomicBool>,
    // Value reported by `was_stopped_by_user`; tests toggle it via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    // Firing this sender resolves the exit future; taken on first `signal_exit`.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Shared future that resolves with the terminal's exit status.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    id: acp::TerminalId,
}
71
72impl FakeTerminalHandle {
73 fn new_never_exits(cx: &mut App) -> Self {
74 let killed = Arc::new(AtomicBool::new(false));
75 let stopped_by_user = Arc::new(AtomicBool::new(false));
76
77 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
78
79 let wait_for_exit = cx
80 .spawn(async move |_cx| {
81 // Wait for the exit signal (sent when kill() is called)
82 let _ = exit_receiver.await;
83 acp::TerminalExitStatus::new()
84 })
85 .shared();
86
87 Self {
88 killed,
89 stopped_by_user,
90 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
91 wait_for_exit,
92 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
93 id: acp::TerminalId::new("fake_terminal".to_string()),
94 }
95 }
96
97 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
98 let killed = Arc::new(AtomicBool::new(false));
99 let stopped_by_user = Arc::new(AtomicBool::new(false));
100 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
101
102 let wait_for_exit = cx
103 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
104 .shared();
105
106 Self {
107 killed,
108 stopped_by_user,
109 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
110 wait_for_exit,
111 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
112 id: acp::TerminalId::new("fake_terminal".to_string()),
113 }
114 }
115
116 fn was_killed(&self) -> bool {
117 self.killed.load(Ordering::SeqCst)
118 }
119
120 fn set_stopped_by_user(&self, stopped: bool) {
121 self.stopped_by_user.store(stopped, Ordering::SeqCst);
122 }
123
124 fn signal_exit(&self) {
125 if let Some(sender) = self.exit_sender.borrow_mut().take() {
126 let _ = sender.send(());
127 }
128 }
129}
130
impl crate::TerminalHandle for FakeTerminalHandle {
    /// Returns the fixed fake terminal id.
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    /// Returns the canned output snapshot configured at construction.
    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    /// Hands out a clone of the shared exit future.
    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    /// Records the kill, then resolves the exit future so waiters wake up.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        self.signal_exit();
        Ok(())
    }

    /// Reports the flag set via `set_stopped_by_user`.
    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
        Ok(self.stopped_by_user.load(Ordering::SeqCst))
    }
}
154
/// Thread environment that always hands out the same pre-built terminal handle.
struct FakeThreadEnvironment {
    handle: Rc<FakeTerminalHandle>,
}
158
159impl crate::ThreadEnvironment for FakeThreadEnvironment {
160 fn create_terminal(
161 &self,
162 _command: String,
163 _cwd: Option<std::path::PathBuf>,
164 _output_byte_limit: Option<u64>,
165 _cx: &mut AsyncApp,
166 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
167 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
168 }
169}
170
171/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
172struct MultiTerminalEnvironment {
173 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
174}
175
176impl MultiTerminalEnvironment {
177 fn new() -> Self {
178 Self {
179 handles: std::cell::RefCell::new(Vec::new()),
180 }
181 }
182
183 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
184 self.handles.borrow().clone()
185 }
186}
187
188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
189 fn create_terminal(
190 &self,
191 _command: String,
192 _cwd: Option<std::path::PathBuf>,
193 _output_byte_limit: Option<u64>,
194 cx: &mut AsyncApp,
195 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
196 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
197 self.handles.borrow_mut().push(handle.clone());
198 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
199 }
200}
201
202fn always_allow_tools(cx: &mut TestAppContext) {
203 cx.update(|cx| {
204 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
205 settings.always_allow_tool_actions = true;
206 agent_settings::AgentSettings::override_global(settings, cx);
207 });
208}
209
210#[gpui::test]
211async fn test_echo(cx: &mut TestAppContext) {
212 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
213 let fake_model = model.as_fake();
214
215 let events = thread
216 .update(cx, |thread, cx| {
217 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
218 })
219 .unwrap();
220 cx.run_until_parked();
221 fake_model.send_last_completion_stream_text_chunk("Hello");
222 fake_model
223 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
224 fake_model.end_last_completion_stream();
225
226 let events = events.collect().await;
227 thread.update(cx, |thread, _cx| {
228 assert_eq!(
229 thread.last_message().unwrap().to_markdown(),
230 indoc! {"
231 ## Assistant
232
233 Hello
234 "}
235 )
236 });
237 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
238}
239
/// A terminal command that outlives its `timeout_ms` must be killed, and the
/// tool's result should still carry the output captured so far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own, so only the timeout can end it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Run with a tiny (5ms) timeout against a command that would never finish.
    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first tool-call update should expose the terminal to the UI.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll cooperatively: `run_until_parked` drives the executors while the
    // real-time deadline guards against the task hanging forever.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
305
306#[gpui::test]
307#[ignore]
308async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
309 init_test(cx);
310 always_allow_tools(cx);
311
312 let fs = FakeFs::new(cx.executor());
313 let project = Project::test(fs, [], cx).await;
314
315 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
316 let environment = Rc::new(FakeThreadEnvironment {
317 handle: handle.clone(),
318 });
319
320 #[allow(clippy::arc_with_non_send_sync)]
321 let tool = Arc::new(crate::TerminalTool::new(project, environment));
322 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
323
324 let _task = cx.update(|cx| {
325 tool.run(
326 crate::TerminalToolInput {
327 command: "sleep 1000".to_string(),
328 cd: ".".to_string(),
329 timeout_ms: None,
330 },
331 event_stream,
332 cx,
333 )
334 });
335
336 let update = rx.expect_update_fields().await;
337 assert!(
338 update.content.iter().any(|blocks| {
339 blocks
340 .iter()
341 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
342 }),
343 "expected tool call update to include terminal content"
344 );
345
346 cx.background_executor
347 .timer(Duration::from_millis(25))
348 .await;
349
350 assert!(
351 !handle.was_killed(),
352 "did not expect terminal handle to be killed without a timeout"
353 );
354}
355
356#[gpui::test]
357async fn test_thinking(cx: &mut TestAppContext) {
358 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
359 let fake_model = model.as_fake();
360
361 let events = thread
362 .update(cx, |thread, cx| {
363 thread.send(
364 UserMessageId::new(),
365 [indoc! {"
366 Testing:
367
368 Generate a thinking step where you just think the word 'Think',
369 and have your final answer be 'Hello'
370 "}],
371 cx,
372 )
373 })
374 .unwrap();
375 cx.run_until_parked();
376 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
377 text: "Think".to_string(),
378 signature: None,
379 });
380 fake_model.send_last_completion_stream_text_chunk("Hello");
381 fake_model
382 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
383 fake_model.end_last_completion_stream();
384
385 let events = events.collect().await;
386 thread.update(cx, |thread, _cx| {
387 assert_eq!(
388 thread.last_message().unwrap().to_markdown(),
389 indoc! {"
390 ## Assistant
391
392 <think>Think</think>
393 Hello
394 "}
395 )
396 });
397 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
398}
399
400#[gpui::test]
401async fn test_system_prompt(cx: &mut TestAppContext) {
402 let ThreadTest {
403 model,
404 thread,
405 project_context,
406 ..
407 } = setup(cx, TestModel::Fake).await;
408 let fake_model = model.as_fake();
409
410 project_context.update(cx, |project_context, _cx| {
411 project_context.shell = "test-shell".into()
412 });
413 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
414 thread
415 .update(cx, |thread, cx| {
416 thread.send(UserMessageId::new(), ["abc"], cx)
417 })
418 .unwrap();
419 cx.run_until_parked();
420 let mut pending_completions = fake_model.pending_completions();
421 assert_eq!(
422 pending_completions.len(),
423 1,
424 "unexpected pending completions: {:?}",
425 pending_completions
426 );
427
428 let pending_completion = pending_completions.pop().unwrap();
429 assert_eq!(pending_completion.messages[0].role, Role::System);
430
431 let system_message = &pending_completion.messages[0];
432 let system_prompt = system_message.content[0].to_str().unwrap();
433 assert!(
434 system_prompt.contains("test-shell"),
435 "unexpected system message: {:?}",
436 system_message
437 );
438 assert!(
439 system_prompt.contains("## Fixing Diagnostics"),
440 "unexpected system message: {:?}",
441 system_message
442 );
443}
444
445#[gpui::test]
446async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
448 let fake_model = model.as_fake();
449
450 thread
451 .update(cx, |thread, cx| {
452 thread.send(UserMessageId::new(), ["abc"], cx)
453 })
454 .unwrap();
455 cx.run_until_parked();
456 let mut pending_completions = fake_model.pending_completions();
457 assert_eq!(
458 pending_completions.len(),
459 1,
460 "unexpected pending completions: {:?}",
461 pending_completions
462 );
463
464 let pending_completion = pending_completions.pop().unwrap();
465 assert_eq!(pending_completion.messages[0].role, Role::System);
466
467 let system_message = &pending_completion.messages[0];
468 let system_prompt = system_message.content[0].to_str().unwrap();
469 assert!(
470 !system_prompt.contains("## Tool Use"),
471 "unexpected system message: {:?}",
472 system_message
473 );
474 assert!(
475 !system_prompt.contains("## Fixing Diagnostics"),
476 "unexpected system message: {:?}",
477 system_message
478 );
479}
480
/// Verifies the cache-anchor policy: in every request, only the newest message
/// carries `cache: true`; all earlier messages are sent with `cache: false`.
/// Note: `messages[0]` is the system prompt and is excluded from the assertions.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (carrying the tool result) should anchor the cache
    // on the tool-result message, the newest entry in the conversation.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
626
/// End-to-end test (requires the `e2e` feature and a real model connection).
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should have been echoed into the
        // assistant's final message text.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
686
687#[gpui::test]
688#[cfg_attr(not(feature = "e2e"), ignore)]
689async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
690 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
691
692 // Test a tool call that's likely to complete *before* streaming stops.
693 let mut events = thread
694 .update(cx, |thread, cx| {
695 thread.add_tool(WordListTool);
696 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
697 })
698 .unwrap();
699
700 let mut saw_partial_tool_use = false;
701 while let Some(event) = events.next().await {
702 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
703 thread.update(cx, |thread, _cx| {
704 // Look for a tool use in the thread's last message
705 let message = thread.last_message().unwrap();
706 let agent_message = message.as_agent_message().unwrap();
707 let last_content = agent_message.content.last().unwrap();
708 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
709 assert_eq!(last_tool_use.name.as_ref(), "word_list");
710 if tool_call.status == acp::ToolCallStatus::Pending {
711 if !last_tool_use.is_input_complete
712 && last_tool_use.input.get("g").is_none()
713 {
714 saw_partial_tool_use = true;
715 }
716 } else {
717 last_tool_use
718 .input
719 .get("a")
720 .expect("'a' has streamed because input is now complete");
721 last_tool_use
722 .input
723 .get("g")
724 .expect("'g' has streamed because input is now complete");
725 }
726 } else {
727 panic!("last content should be a tool use");
728 }
729 });
730 }
731 }
732
733 assert!(
734 saw_partial_tool_use,
735 "should see at least one partially streamed tool use in the history"
736 );
737}
738
739#[gpui::test]
740async fn test_tool_authorization(cx: &mut TestAppContext) {
741 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
742 let fake_model = model.as_fake();
743
744 let mut events = thread
745 .update(cx, |thread, cx| {
746 thread.add_tool(ToolRequiringPermission);
747 thread.send(UserMessageId::new(), ["abc"], cx)
748 })
749 .unwrap();
750 cx.run_until_parked();
751 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
752 LanguageModelToolUse {
753 id: "tool_id_1".into(),
754 name: ToolRequiringPermission::name().into(),
755 raw_input: "{}".into(),
756 input: json!({}),
757 is_input_complete: true,
758 thought_signature: None,
759 },
760 ));
761 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
762 LanguageModelToolUse {
763 id: "tool_id_2".into(),
764 name: ToolRequiringPermission::name().into(),
765 raw_input: "{}".into(),
766 input: json!({}),
767 is_input_complete: true,
768 thought_signature: None,
769 },
770 ));
771 fake_model.end_last_completion_stream();
772 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
773 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
774
775 // Approve the first
776 tool_call_auth_1
777 .response
778 .send(tool_call_auth_1.options[1].option_id.clone())
779 .unwrap();
780 cx.run_until_parked();
781
782 // Reject the second
783 tool_call_auth_2
784 .response
785 .send(tool_call_auth_1.options[2].option_id.clone())
786 .unwrap();
787 cx.run_until_parked();
788
789 let completion = fake_model.pending_completions().pop().unwrap();
790 let message = completion.messages.last().unwrap();
791 assert_eq!(
792 message.content,
793 vec![
794 language_model::MessageContent::ToolResult(LanguageModelToolResult {
795 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
796 tool_name: ToolRequiringPermission::name().into(),
797 is_error: false,
798 content: "Allowed".into(),
799 output: Some("Allowed".into())
800 }),
801 language_model::MessageContent::ToolResult(LanguageModelToolResult {
802 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
803 tool_name: ToolRequiringPermission::name().into(),
804 is_error: true,
805 content: "Permission to run tool denied by user".into(),
806 output: Some("Permission to run tool denied by user".into())
807 })
808 ]
809 );
810
811 // Simulate yet another tool call.
812 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
813 LanguageModelToolUse {
814 id: "tool_id_3".into(),
815 name: ToolRequiringPermission::name().into(),
816 raw_input: "{}".into(),
817 input: json!({}),
818 is_input_complete: true,
819 thought_signature: None,
820 },
821 ));
822 fake_model.end_last_completion_stream();
823
824 // Respond by always allowing tools.
825 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
826 tool_call_auth_3
827 .response
828 .send(tool_call_auth_3.options[0].option_id.clone())
829 .unwrap();
830 cx.run_until_parked();
831 let completion = fake_model.pending_completions().pop().unwrap();
832 let message = completion.messages.last().unwrap();
833 assert_eq!(
834 message.content,
835 vec![language_model::MessageContent::ToolResult(
836 LanguageModelToolResult {
837 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
838 tool_name: ToolRequiringPermission::name().into(),
839 is_error: false,
840 content: "Allowed".into(),
841 output: Some("Allowed".into())
842 }
843 )]
844 );
845
846 // Simulate a final tool call, ensuring we don't trigger authorization.
847 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
848 LanguageModelToolUse {
849 id: "tool_id_4".into(),
850 name: ToolRequiringPermission::name().into(),
851 raw_input: "{}".into(),
852 input: json!({}),
853 is_input_complete: true,
854 thought_signature: None,
855 },
856 ));
857 fake_model.end_last_completion_stream();
858 cx.run_until_parked();
859 let completion = fake_model.pending_completions().pop().unwrap();
860 let message = completion.messages.last().unwrap();
861 assert_eq!(
862 message.content,
863 vec![language_model::MessageContent::ToolResult(
864 LanguageModelToolResult {
865 tool_use_id: "tool_id_4".into(),
866 tool_name: ToolRequiringPermission::name().into(),
867 is_error: false,
868 content: "Allowed".into(),
869 output: Some("Allowed".into())
870 }
871 )]
872 );
873}
874
875#[gpui::test]
876async fn test_tool_hallucination(cx: &mut TestAppContext) {
877 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
878 let fake_model = model.as_fake();
879
880 let mut events = thread
881 .update(cx, |thread, cx| {
882 thread.send(UserMessageId::new(), ["abc"], cx)
883 })
884 .unwrap();
885 cx.run_until_parked();
886 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
887 LanguageModelToolUse {
888 id: "tool_id_1".into(),
889 name: "nonexistent_tool".into(),
890 raw_input: "{}".into(),
891 input: json!({}),
892 is_input_complete: true,
893 thought_signature: None,
894 },
895 ));
896 fake_model.end_last_completion_stream();
897
898 let tool_call = expect_tool_call(&mut events).await;
899 assert_eq!(tool_call.title, "nonexistent_tool");
900 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
901 let update = expect_tool_call_update_fields(&mut events).await;
902 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
903}
904
/// After the model reports `ToolUseLimitReached`, `resume` should replay the
/// conversation and append a synthetic "Continue where you left off" user
/// message as the new cache anchor.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool runs and its result is sent back in a follow-up request.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming should re-send the conversation plus the continuation message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
1019
/// After `ToolUseLimitReached`, sending a brand-new user message (rather than
/// resuming) should work: the history is replayed with the new message
/// appended as the cache anchor, and no continuation message is inserted.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The limit arrives in the same stream as the tool use, so the turn ends
    // with a ToolUseLimitReachedError.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1096
1097async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1098 let event = events
1099 .next()
1100 .await
1101 .expect("no tool call authorization event received")
1102 .unwrap();
1103 match event {
1104 ThreadEvent::ToolCall(tool_call) => tool_call,
1105 event => {
1106 panic!("Unexpected event {event:?}");
1107 }
1108 }
1109}
1110
1111async fn expect_tool_call_update_fields(
1112 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1113) -> acp::ToolCallUpdate {
1114 let event = events
1115 .next()
1116 .await
1117 .expect("no tool call authorization event received")
1118 .unwrap();
1119 match event {
1120 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1121 event => {
1122 panic!("Unexpected event {event:?}");
1123 }
1124 }
1125}
1126
1127async fn next_tool_call_authorization(
1128 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1129) -> ToolCallAuthorization {
1130 loop {
1131 let event = events
1132 .next()
1133 .await
1134 .expect("no tool call authorization event received")
1135 .unwrap();
1136 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1137 let permission_kinds = tool_call_authorization
1138 .options
1139 .iter()
1140 .map(|o| o.kind)
1141 .collect::<Vec<_>>();
1142 assert_eq!(
1143 permission_kinds,
1144 vec![
1145 acp::PermissionOptionKind::AllowAlways,
1146 acp::PermissionOptionKind::AllowOnce,
1147 acp::PermissionOptionKind::RejectOnce,
1148 ]
1149 );
1150 return tool_call_authorization;
1151 }
1152 }
1153}
1154
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // E2E (real model) test: two delay-tool calls issued in the same turn
    // should both run, and the model should describe their outputs afterward.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    // The turn should end normally (no cancellation / refusal).
    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate just the text parts of the final agent message.
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        // "Ding" is what DelayTool outputs when its timer completes, so the
        // model's description of the outputs should mention it.
        assert!(text.contains("Ding"));
    });
}
1199
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which tools are offered
    // to the model: test-1 exposes echo+delay, test-2 exposes only infinite.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the active profile decides
    // which subset actually reaches the model request.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1279
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Exercises MCP (context-server) tools end to end with a fake model:
    // an MCP tool is exposed to the model, invoked, and answered; then a
    // native tool with the same name ("echo") is added to verify that the
    // collision is resolved by prefixing the MCP tool with its server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a context server exposing a single "echo" tool whose schema
    // mirrors the native EchoTool's input schema.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should receive the call with the original arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` popped before the tool
    // call, not on a freshly popped pending completion — likely a duplicated
    // assertion rather than a check of the follow-up request; verify intent.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Native tool keeps "echo"; the MCP tool is renamed "test_server_echo".
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server is called under its original name "echo" even though the
    // model addressed it by the prefixed name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1445
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP tool names are disambiguated and truncated when
    // multiple servers provide conflicting and/or over-long tool names
    // (relative to MAX_TOOL_NAME_LENGTH).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names that also collide across servers yyy and zzz:
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One name strictly over the limit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected (sorted) names: unconflicted names pass through; conflicting
    // "echo" gets server-name prefixes; colliding near-limit names get a
    // truncated server prefix ("y_"/"z_") with the tail of the name trimmed
    // to fit; the over-limit "c…" name is simply cut to the limit.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1611
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // E2E (real model) test: cancelling mid-turn closes the event stream with
    // a Cancelled stop reason even while a tool is still running, and the
    // thread remains usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools must be called in the requested order.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Track when the echo tool (specifically) reports Completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools started and echo finished (infinite never will).
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1697
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // When a running terminal tool is cancelled, the terminal must be killed
    // and the tool result must preserve the partial output captured so far,
    // plus a note that the user stopped the command.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal whose command never exits on its own.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the terminal tool use recorded in the agent message...
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // ...and its corresponding result.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1794
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a flag the tool sets when it observes cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported
    // (polling loop bounded by a deadline scaled with available parallelism).
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1889
1890/// Helper to verify thread can recover after cancellation by sending a simple message.
1891async fn verify_thread_recovery(
1892 thread: &Entity<Thread>,
1893 fake_model: &FakeLanguageModel,
1894 cx: &mut TestAppContext,
1895) {
1896 let events = thread
1897 .update(cx, |thread, cx| {
1898 thread.send(
1899 UserMessageId::new(),
1900 ["Testing: reply with 'Hello' then stop."],
1901 cx,
1902 )
1903 })
1904 .unwrap();
1905 cx.run_until_parked();
1906 fake_model.send_last_completion_stream_text_chunk("Hello");
1907 fake_model
1908 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1909 fake_model.end_last_completion_stream();
1910
1911 let events = events.collect::<Vec<_>>().await;
1912 thread.update(cx, |thread, _cx| {
1913 let message = thread.last_message().unwrap();
1914 let agent_message = message.as_agent_message().unwrap();
1915 assert_eq!(
1916 agent_message.content,
1917 vec![AgentMessageContent::Text("Hello".to_string())]
1918 );
1919 });
1920 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1921}
1922
1923/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1924async fn wait_for_terminal_tool_started(
1925 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1926 cx: &mut TestAppContext,
1927) {
1928 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1929 for _ in 0..deadline {
1930 cx.run_until_parked();
1931
1932 while let Some(Some(event)) = events.next().now_or_never() {
1933 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1934 update,
1935 ))) = &event
1936 {
1937 if update.fields.content.as_ref().is_some_and(|content| {
1938 content
1939 .iter()
1940 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1941 }) {
1942 return;
1943 }
1944 }
1945 }
1946
1947 cx.background_executor
1948 .timer(Duration::from_millis(10))
1949 .await;
1950 }
1951 panic!("terminal tool did not start within the expected time");
1952}
1953
1954/// Collects events until a Stop event is received, driving the executor to completion.
1955async fn collect_events_until_stop(
1956 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1957 cx: &mut TestAppContext,
1958) -> Vec<Result<ThreadEvent>> {
1959 let mut collected = Vec::new();
1960 let deadline = cx.executor().num_cpus() * 200;
1961
1962 for _ in 0..deadline {
1963 cx.executor().advance_clock(Duration::from_millis(10));
1964 cx.run_until_parked();
1965
1966 while let Some(Some(event)) = events.next().now_or_never() {
1967 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1968 collected.push(event);
1969 if is_stop {
1970 return collected;
1971 }
1972 }
1973 }
1974 panic!(
1975 "did not receive Stop event within the expected time; collected {} events",
1976 collected.len()
1977 );
1978}
1979
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncating the thread while a terminal tool is mid-run must kill the
    // terminal, leave the thread empty (the only message was truncated), and
    // keep the thread usable for new sends.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal whose command never exits on its own.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the message id so we can truncate back to (and including) it.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2046
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment that creates a fresh handle per terminal, so each can be
    // inspected individually after cancellation.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates
    // (non-blocking polling loop, bounded by a parallelism-scaled deadline).
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2155
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A handle whose underlying "process" never exits on its own; the test drives
    // its exit manually via signal_exit below.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Register the terminal tool (backed by the fake environment) and start a turn.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the terminal tool use recorded in the agent's message content.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // The result the model sees must mention that the user stopped the command,
        // i.e. the was_stopped_by_user path was taken rather than a plain kill.
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2250
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A handle whose "process" never exits by itself, so only the timeout can end it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Register the terminal tool and start a turn.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (200ms > the 100ms timeout_ms above).
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the terminal tool use recorded in the agent's message content.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // Timeout and user-stop are distinct outcomes; make sure the timeout
        // message is reported and the user-stop wording does NOT leak in.
        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2349
2350#[gpui::test]
2351async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2352 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2353 let fake_model = model.as_fake();
2354
2355 let events_1 = thread
2356 .update(cx, |thread, cx| {
2357 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2358 })
2359 .unwrap();
2360 cx.run_until_parked();
2361 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2362 cx.run_until_parked();
2363
2364 let events_2 = thread
2365 .update(cx, |thread, cx| {
2366 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2367 })
2368 .unwrap();
2369 cx.run_until_parked();
2370 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2371 fake_model
2372 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2373 fake_model.end_last_completion_stream();
2374
2375 let events_1 = events_1.collect::<Vec<_>>().await;
2376 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2377 let events_2 = events_2.collect::<Vec<_>>().await;
2378 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2379}
2380
2381#[gpui::test]
2382async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2383 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2384 let fake_model = model.as_fake();
2385
2386 let events_1 = thread
2387 .update(cx, |thread, cx| {
2388 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2389 })
2390 .unwrap();
2391 cx.run_until_parked();
2392 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2393 fake_model
2394 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2395 fake_model.end_last_completion_stream();
2396 let events_1 = events_1.collect::<Vec<_>>().await;
2397
2398 let events_2 = thread
2399 .update(cx, |thread, cx| {
2400 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2401 })
2402 .unwrap();
2403 cx.run_until_parked();
2404 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2405 fake_model
2406 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2407 fake_model.end_last_completion_stream();
2408 let events_2 = events_2.collect::<Vec<_>>().await;
2409
2410 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2411 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2412}
2413
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // Verifies that a Refusal stop reason rolls the thread back to before the
    // last user message (all messages from that turn are removed).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any model output, only the user message is present.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The whole turn (including the user message) is gone after the refusal.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2462
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message should empty the thread, reset token
    // usage, and still allow sending fresh messages afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update so latest_token_usage becomes Some.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the UsageUpdate above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncate at the first message: thread content and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage now reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                output_tokens: 20_000,
            })
        );
    });
}
2580
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message should restore the thread (content and
    // token usage) to exactly the state after the first exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: user message + assistant reply + usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before the second exchange and after
    // truncation, so the two states are verified to be identical.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                    output_tokens: 16_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, keeping the id so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges are present and usage reflects the latest update.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                output_tokens: 20_000,
            })
        );
    });

    // Truncate at the second message and verify we are back to the first state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
2691
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model generates the thread title after the first
    // turn; subsequent turns must not regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model dedicated to summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title stays at the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Chunks are split across token boundaries on purpose; only the text before
    // the first newline becomes the title ("Hello world").
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No pending completion on the summary model means no second title request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2737
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // A tool call still awaiting user permission must be omitted from the next
    // completion request, while a completed tool call (and its result) is kept.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will block on permission and therefore stay pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This one runs immediately and produces a result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // Compare from index 1 — message 0 is skipped here
    // (NOTE(review): presumably the system/context message; confirm against build_completion_request).
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // Only the completed echo tool use appears; the permission-pending
            // tool use is absent from the assistant message.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2817
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of the NativeAgentConnection: thread creation, model
    // selection, prompting, cancellation, and session teardown on drop.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        // 404 responses keep the client offline; only the fake registry model is used.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let thread_store = cx.new(|cx| ThreadStore::new(cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        thread_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The fake registry exposes a single provider/model pair: "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected id back through the agent's model registry.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream back a reply through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with a "Session not found" error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2946
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the exact sequence of ACP tool-call events emitted for a single
    // tool invocation: initial ToolCall, then field updates for streamed input,
    // InProgress status, content, and finally Completed with raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. Initial tool call created from the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Thinking")
            .kind(acp::ToolKind::Think)
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([(
                "tool_name".into(),
                "thinking".into()
            )]))
    );
    // 2. Fields updated once the complete input arrives.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Thinking")
                .kind(acp::ToolKind::Think)
                .raw_input(json!({ "content": "Thinking hard!"}))
        )
    );
    // 3. Status flips to InProgress when the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    // 4. Tool content is streamed as a content update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
        )
    );
    // 5. Final update: Completed status plus the raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Finished thinking.")
        )
    );
}
3037
3038#[gpui::test]
3039async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3040 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3041 let fake_model = model.as_fake();
3042
3043 let mut events = thread
3044 .update(cx, |thread, cx| {
3045 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3046 thread.send(UserMessageId::new(), ["Hello!"], cx)
3047 })
3048 .unwrap();
3049 cx.run_until_parked();
3050
3051 fake_model.send_last_completion_stream_text_chunk("Hey!");
3052 fake_model.end_last_completion_stream();
3053
3054 let mut retry_events = Vec::new();
3055 while let Some(Ok(event)) = events.next().await {
3056 match event {
3057 ThreadEvent::Retry(retry_status) => {
3058 retry_events.push(retry_status);
3059 }
3060 ThreadEvent::Stop(..) => break,
3061 _ => {}
3062 }
3063 }
3064
3065 assert_eq!(retry_events.len(), 0);
3066 thread.read_with(cx, |thread, _cx| {
3067 assert_eq!(
3068 thread.to_markdown(),
3069 indoc! {"
3070 ## User
3071
3072 Hello!
3073
3074 ## Assistant
3075
3076 Hey!
3077 "}
3078 )
3079 });
3080}
3081
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A ServerOverloaded error mid-stream should trigger exactly one retry
    // (after the provider-supplied delay) that resumes the turn.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with an overloaded error carrying a
    // 3-second retry_after hint.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advancing the fake clock past retry_after lets the retry fire.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events, collecting retries until the final Stop.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript shows the interrupted reply, a resume marker, and the
    // continuation in a fresh assistant message.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3146
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a completion errors after emitting a tool use, the tool still runs and
    // its result is included in the retried request.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit the tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, inspect the retried request: it must contain the
    // tool use AND its completed result (compare from index 1 — message 0 is
    // skipped here; NOTE(review): presumably the system/context message).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            // The echo tool ran despite the stream error; note cache: true has
            // moved to this latest message.
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3227
3228#[gpui::test]
3229async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3230 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3231 let fake_model = model.as_fake();
3232
3233 let mut events = thread
3234 .update(cx, |thread, cx| {
3235 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3236 thread.send(UserMessageId::new(), ["Hello!"], cx)
3237 })
3238 .unwrap();
3239 cx.run_until_parked();
3240
3241 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3242 fake_model.send_last_completion_stream_error(
3243 LanguageModelCompletionError::ServerOverloaded {
3244 provider: LanguageModelProviderName::new("Anthropic"),
3245 retry_after: Some(Duration::from_secs(3)),
3246 },
3247 );
3248 fake_model.end_last_completion_stream();
3249 cx.executor().advance_clock(Duration::from_secs(3));
3250 cx.run_until_parked();
3251 }
3252
3253 let mut errors = Vec::new();
3254 let mut retry_events = Vec::new();
3255 while let Some(event) = events.next().await {
3256 match event {
3257 Ok(ThreadEvent::Retry(retry_status)) => {
3258 retry_events.push(retry_status);
3259 }
3260 Ok(ThreadEvent::Stop(..)) => break,
3261 Err(error) => errors.push(error),
3262 _ => {}
3263 }
3264 }
3265
3266 assert_eq!(
3267 retry_events.len(),
3268 crate::thread::MAX_RETRY_ATTEMPTS as usize
3269 );
3270 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3271 assert_eq!(retry_events[i].attempt, i + 1);
3272 }
3273 assert_eq!(errors.len(), 1);
3274 let error = errors[0]
3275 .downcast_ref::<LanguageModelCompletionError>()
3276 .unwrap();
3277 assert!(matches!(
3278 error,
3279 LanguageModelCompletionError::ServerOverloaded { .. }
3280 ));
3281}
3282
3283/// Filters out the stop events for asserting against in tests
3284fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3285 result_events
3286 .into_iter()
3287 .filter_map(|event| match event.unwrap() {
3288 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3289 _ => None,
3290 })
3291 .collect()
3292}
3293
/// Fixture produced by [`setup`]; individual tests destructure the pieces
/// they need.
struct ThreadTest {
    // The model the thread was constructed with (fake or real, per `TestModel`).
    model: Arc<dyn LanguageModel>,
    // The thread under test, rooted at the `/test` project.
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing both the project and the settings file.
    fs: Arc<FakeFs>,
}
3301
/// Selects which language model [`setup`] wires into the thread.
enum TestModel {
    // Real Anthropic model resolved from the registry (requires auth).
    Sonnet4,
    // In-process `FakeLanguageModel`; the test drives its completions.
    Fake,
}
3306
impl TestModel {
    /// Registry id used to look up a real model in `setup`.
    ///
    /// `Fake` is unreachable here because `setup` constructs the fake model
    /// directly and never resolves it by id.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
3315
/// Builds the common test fixture: a settings file granting a test profile
/// access to every test tool, a project rooted at `/test`, the requested
/// model (fake or real), and a `Thread` wired to all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write an agent settings file enabling all test tools under
    // "test-profile" so tool calls in tests aren't filtered out.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Real models need the full client/registry stack; the fake model
        // needs no global initialization at all.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with edits to the fake
        // settings file made later by individual tests.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // For a real model: resolve it from the registry and authenticate its
    // provider before handing it to the thread.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3419
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up env_logger when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
3427
3428fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3429 let fs = fs.clone();
3430 cx.spawn({
3431 async move |cx| {
3432 let mut new_settings_content_rx = settings::watch_config_file(
3433 cx.background_executor(),
3434 fs,
3435 paths::settings_file().clone(),
3436 );
3437
3438 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3439 cx.update(|cx| {
3440 SettingsStore::update_global(cx, |settings, cx| {
3441 settings.set_user_settings(&new_settings_content, cx)
3442 })
3443 })
3444 .ok();
3445 }
3446 }
3447 })
3448 .detach();
3449}
3450
3451fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3452 completion
3453 .tools
3454 .iter()
3455 .map(|tool| tool.name.clone())
3456 .collect()
3457}
3458
/// Registers a fake stdio context server named `name` in the project
/// settings, starts it in `context_server_store`, and returns a receiver
/// that yields each incoming tool call together with a oneshot sender the
/// test uses to supply the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is a dummy — the fake transport below intercepts all I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake MCP transport: answers the handshake, advertises `tools`, and
    // forwards CallTool requests to the test through the channel.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it replies.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its startup handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3537
3538#[gpui::test]
3539async fn test_tokens_before_message(cx: &mut TestAppContext) {
3540 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3541 let fake_model = model.as_fake();
3542
3543 // First message
3544 let message_1_id = UserMessageId::new();
3545 thread
3546 .update(cx, |thread, cx| {
3547 thread.send(message_1_id.clone(), ["First message"], cx)
3548 })
3549 .unwrap();
3550 cx.run_until_parked();
3551
3552 // Before any response, tokens_before_message should return None for first message
3553 thread.read_with(cx, |thread, _| {
3554 assert_eq!(
3555 thread.tokens_before_message(&message_1_id),
3556 None,
3557 "First message should have no tokens before it"
3558 );
3559 });
3560
3561 // Complete first message with usage
3562 fake_model.send_last_completion_stream_text_chunk("Response 1");
3563 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3564 language_model::TokenUsage {
3565 input_tokens: 100,
3566 output_tokens: 50,
3567 cache_creation_input_tokens: 0,
3568 cache_read_input_tokens: 0,
3569 },
3570 ));
3571 fake_model.end_last_completion_stream();
3572 cx.run_until_parked();
3573
3574 // First message still has no tokens before it
3575 thread.read_with(cx, |thread, _| {
3576 assert_eq!(
3577 thread.tokens_before_message(&message_1_id),
3578 None,
3579 "First message should still have no tokens before it after response"
3580 );
3581 });
3582
3583 // Second message
3584 let message_2_id = UserMessageId::new();
3585 thread
3586 .update(cx, |thread, cx| {
3587 thread.send(message_2_id.clone(), ["Second message"], cx)
3588 })
3589 .unwrap();
3590 cx.run_until_parked();
3591
3592 // Second message should have first message's input tokens before it
3593 thread.read_with(cx, |thread, _| {
3594 assert_eq!(
3595 thread.tokens_before_message(&message_2_id),
3596 Some(100),
3597 "Second message should have 100 tokens before it (from first request)"
3598 );
3599 });
3600
3601 // Complete second message
3602 fake_model.send_last_completion_stream_text_chunk("Response 2");
3603 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3604 language_model::TokenUsage {
3605 input_tokens: 250, // Total for this request (includes previous context)
3606 output_tokens: 75,
3607 cache_creation_input_tokens: 0,
3608 cache_read_input_tokens: 0,
3609 },
3610 ));
3611 fake_model.end_last_completion_stream();
3612 cx.run_until_parked();
3613
3614 // Third message
3615 let message_3_id = UserMessageId::new();
3616 thread
3617 .update(cx, |thread, cx| {
3618 thread.send(message_3_id.clone(), ["Third message"], cx)
3619 })
3620 .unwrap();
3621 cx.run_until_parked();
3622
3623 // Third message should have second message's input tokens (250) before it
3624 thread.read_with(cx, |thread, _| {
3625 assert_eq!(
3626 thread.tokens_before_message(&message_3_id),
3627 Some(250),
3628 "Third message should have 250 tokens before it (from second request)"
3629 );
3630 // Second message should still have 100
3631 assert_eq!(
3632 thread.tokens_before_message(&message_2_id),
3633 Some(100),
3634 "Second message should still have 100 tokens before it"
3635 );
3636 // First message still has none
3637 assert_eq!(
3638 thread.tokens_before_message(&message_1_id),
3639 None,
3640 "First message should still have no tokens before it"
3641 );
3642 });
3643}
3644
3645#[gpui::test]
3646async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3647 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3648 let fake_model = model.as_fake();
3649
3650 // Set up three messages with responses
3651 let message_1_id = UserMessageId::new();
3652 thread
3653 .update(cx, |thread, cx| {
3654 thread.send(message_1_id.clone(), ["Message 1"], cx)
3655 })
3656 .unwrap();
3657 cx.run_until_parked();
3658 fake_model.send_last_completion_stream_text_chunk("Response 1");
3659 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3660 language_model::TokenUsage {
3661 input_tokens: 100,
3662 output_tokens: 50,
3663 cache_creation_input_tokens: 0,
3664 cache_read_input_tokens: 0,
3665 },
3666 ));
3667 fake_model.end_last_completion_stream();
3668 cx.run_until_parked();
3669
3670 let message_2_id = UserMessageId::new();
3671 thread
3672 .update(cx, |thread, cx| {
3673 thread.send(message_2_id.clone(), ["Message 2"], cx)
3674 })
3675 .unwrap();
3676 cx.run_until_parked();
3677 fake_model.send_last_completion_stream_text_chunk("Response 2");
3678 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3679 language_model::TokenUsage {
3680 input_tokens: 250,
3681 output_tokens: 75,
3682 cache_creation_input_tokens: 0,
3683 cache_read_input_tokens: 0,
3684 },
3685 ));
3686 fake_model.end_last_completion_stream();
3687 cx.run_until_parked();
3688
3689 // Verify initial state
3690 thread.read_with(cx, |thread, _| {
3691 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3692 });
3693
3694 // Truncate at message 2 (removes message 2 and everything after)
3695 thread
3696 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3697 .unwrap();
3698 cx.run_until_parked();
3699
3700 // After truncation, message_2_id no longer exists, so lookup should return None
3701 thread.read_with(cx, |thread, _| {
3702 assert_eq!(
3703 thread.tokens_before_message(&message_2_id),
3704 None,
3705 "After truncation, message 2 no longer exists"
3706 );
3707 // Message 1 still exists but has no tokens before it
3708 assert_eq!(
3709 thread.tokens_before_message(&message_1_id),
3710 None,
3711 "First message still has no tokens before it"
3712 );
3713 });
3714}
3715
/// Exercises the terminal tool's permission rules end to end: deny patterns,
/// allow patterns, forced confirmation, and the default mode — each scenario
/// installs its own `ToolRules` in the global agent settings before running
/// the tool, so the scenarios must stay in this order.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        // The terminal never exits; the command must be rejected before it runs.
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        assert!(
            result.unwrap_err().to_string().contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // Terminal content appearing (rather than an authorization request)
        // proves the allow rule let the command run immediately.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        // Task is intentionally left pending — the test only cares that an
        // authorization request is emitted.
        let _task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let auth = rx.expect_authorization().await;
        assert!(
            auth.tool_call.fields.title.is_some(),
            "expected authorization request for sudo command despite always_allow_tool_actions=true"
        );
    }

    // Test 4: default_mode: Deny blocks commands when no pattern matches
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by default_mode: Deny"
        );
        assert!(
            result.unwrap_err().to_string().contains("disabled"),
            "error should mention the tool is disabled"
        );
    }
}
3931
3932#[gpui::test]
3933async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3934 init_test(cx);
3935
3936 cx.update(|cx| {
3937 cx.update_flags(true, vec!["subagents".to_string()]);
3938 });
3939
3940 let fs = FakeFs::new(cx.executor());
3941 fs.insert_tree(path!("/test"), json!({})).await;
3942 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3943 let project_context = cx.new(|_cx| ProjectContext::default());
3944 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3945 let context_server_registry =
3946 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3947 let model = Arc::new(FakeLanguageModel::default());
3948
3949 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3950 let environment = Rc::new(FakeThreadEnvironment { handle });
3951
3952 let thread = cx.new(|cx| {
3953 let mut thread = Thread::new(
3954 project.clone(),
3955 project_context,
3956 context_server_registry,
3957 Templates::new(),
3958 Some(model),
3959 cx,
3960 );
3961 thread.add_default_tools(environment, cx);
3962 thread
3963 });
3964
3965 thread.read_with(cx, |thread, _| {
3966 assert!(
3967 thread.has_registered_tool("subagent"),
3968 "subagent tool should be present when feature flag is enabled"
3969 );
3970 });
3971}
3972
3973#[gpui::test]
3974async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3975 init_test(cx);
3976
3977 cx.update(|cx| {
3978 cx.update_flags(true, vec!["subagents".to_string()]);
3979 });
3980
3981 let fs = FakeFs::new(cx.executor());
3982 fs.insert_tree(path!("/test"), json!({})).await;
3983 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3984 let project_context = cx.new(|_cx| ProjectContext::default());
3985 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3986 let context_server_registry =
3987 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3988 let model = Arc::new(FakeLanguageModel::default());
3989
3990 let subagent_context = SubagentContext {
3991 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3992 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3993 depth: 1,
3994 summary_prompt: "Summarize".to_string(),
3995 context_low_prompt: "Context low".to_string(),
3996 };
3997
3998 let subagent = cx.new(|cx| {
3999 Thread::new_subagent(
4000 project.clone(),
4001 project_context,
4002 context_server_registry,
4003 Templates::new(),
4004 model.clone(),
4005 subagent_context,
4006 std::collections::BTreeMap::new(),
4007 cx,
4008 )
4009 });
4010
4011 subagent.read_with(cx, |thread, _| {
4012 assert!(thread.is_subagent());
4013 assert_eq!(thread.depth(), 1);
4014 assert!(thread.model().is_some());
4015 });
4016}
4017
4018#[gpui::test]
4019async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4020 init_test(cx);
4021
4022 cx.update(|cx| {
4023 cx.update_flags(true, vec!["subagents".to_string()]);
4024 });
4025
4026 let fs = FakeFs::new(cx.executor());
4027 fs.insert_tree(path!("/test"), json!({})).await;
4028 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4029 let project_context = cx.new(|_cx| ProjectContext::default());
4030 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4031 let context_server_registry =
4032 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4033 let model = Arc::new(FakeLanguageModel::default());
4034
4035 let subagent_context = SubagentContext {
4036 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4037 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4038 depth: MAX_SUBAGENT_DEPTH,
4039 summary_prompt: "Summarize".to_string(),
4040 context_low_prompt: "Context low".to_string(),
4041 };
4042
4043 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
4044 let environment = Rc::new(FakeThreadEnvironment { handle });
4045
4046 let deep_subagent = cx.new(|cx| {
4047 let mut thread = Thread::new_subagent(
4048 project.clone(),
4049 project_context,
4050 context_server_registry,
4051 Templates::new(),
4052 model.clone(),
4053 subagent_context,
4054 std::collections::BTreeMap::new(),
4055 cx,
4056 );
4057 thread.add_default_tools(environment, cx);
4058 thread
4059 });
4060
4061 deep_subagent.read_with(cx, |thread, _| {
4062 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4063 assert!(
4064 !thread.has_registered_tool("subagent"),
4065 "subagent tool should not be present at max depth"
4066 );
4067 });
4068}
4069
4070#[gpui::test]
4071async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
4072 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4073 let fake_model = model.as_fake();
4074
4075 cx.update(|cx| {
4076 cx.update_flags(true, vec!["subagents".to_string()]);
4077 });
4078
4079 let subagent_context = SubagentContext {
4080 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4081 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4082 depth: 1,
4083 summary_prompt: "Summarize your work".to_string(),
4084 context_low_prompt: "Context low, wrap up".to_string(),
4085 };
4086
4087 let project = thread.read_with(cx, |t, _| t.project.clone());
4088 let project_context = cx.new(|_cx| ProjectContext::default());
4089 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4090 let context_server_registry =
4091 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4092
4093 let subagent = cx.new(|cx| {
4094 Thread::new_subagent(
4095 project.clone(),
4096 project_context,
4097 context_server_registry,
4098 Templates::new(),
4099 model.clone(),
4100 subagent_context,
4101 std::collections::BTreeMap::new(),
4102 cx,
4103 )
4104 });
4105
4106 let task_prompt = "Find all TODO comments in the codebase";
4107 subagent
4108 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4109 .unwrap();
4110 cx.run_until_parked();
4111
4112 let pending = fake_model.pending_completions();
4113 assert_eq!(pending.len(), 1, "should have one pending completion");
4114
4115 let messages = &pending[0].messages;
4116 let user_messages: Vec<_> = messages
4117 .iter()
4118 .filter(|m| m.role == language_model::Role::User)
4119 .collect();
4120 assert_eq!(user_messages.len(), 1, "should have one user message");
4121
4122 let content = &user_messages[0].content[0];
4123 assert!(
4124 content.to_str().unwrap().contains("TODO"),
4125 "task prompt should be in user message"
4126 );
4127}
4128
4129#[gpui::test]
4130async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4131 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4132 let fake_model = model.as_fake();
4133
4134 cx.update(|cx| {
4135 cx.update_flags(true, vec!["subagents".to_string()]);
4136 });
4137
4138 let subagent_context = SubagentContext {
4139 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4140 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4141 depth: 1,
4142 summary_prompt: "Please summarize what you found".to_string(),
4143 context_low_prompt: "Context low, wrap up".to_string(),
4144 };
4145
4146 let project = thread.read_with(cx, |t, _| t.project.clone());
4147 let project_context = cx.new(|_cx| ProjectContext::default());
4148 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4149 let context_server_registry =
4150 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4151
4152 let subagent = cx.new(|cx| {
4153 Thread::new_subagent(
4154 project.clone(),
4155 project_context,
4156 context_server_registry,
4157 Templates::new(),
4158 model.clone(),
4159 subagent_context,
4160 std::collections::BTreeMap::new(),
4161 cx,
4162 )
4163 });
4164
4165 subagent
4166 .update(cx, |thread, cx| {
4167 thread.submit_user_message("Do some work", cx)
4168 })
4169 .unwrap();
4170 cx.run_until_parked();
4171
4172 fake_model.send_last_completion_stream_text_chunk("I did the work");
4173 fake_model
4174 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4175 fake_model.end_last_completion_stream();
4176 cx.run_until_parked();
4177
4178 subagent
4179 .update(cx, |thread, cx| thread.request_final_summary(cx))
4180 .unwrap();
4181 cx.run_until_parked();
4182
4183 let pending = fake_model.pending_completions();
4184 assert!(
4185 !pending.is_empty(),
4186 "should have pending completion for summary"
4187 );
4188
4189 let messages = &pending.last().unwrap().messages;
4190 let user_messages: Vec<_> = messages
4191 .iter()
4192 .filter(|m| m.role == language_model::Role::User)
4193 .collect();
4194
4195 let last_user = user_messages.last().unwrap();
4196 assert!(
4197 last_user.content[0].to_str().unwrap().contains("summarize"),
4198 "summary prompt should be sent"
4199 );
4200}
4201
4202#[gpui::test]
4203async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4204 init_test(cx);
4205
4206 cx.update(|cx| {
4207 cx.update_flags(true, vec!["subagents".to_string()]);
4208 });
4209
4210 let fs = FakeFs::new(cx.executor());
4211 fs.insert_tree(path!("/test"), json!({})).await;
4212 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4213 let project_context = cx.new(|_cx| ProjectContext::default());
4214 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4215 let context_server_registry =
4216 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4217 let model = Arc::new(FakeLanguageModel::default());
4218
4219 let subagent_context = SubagentContext {
4220 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4221 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4222 depth: 1,
4223 summary_prompt: "Summarize".to_string(),
4224 context_low_prompt: "Context low".to_string(),
4225 };
4226
4227 let subagent = cx.new(|cx| {
4228 let mut thread = Thread::new_subagent(
4229 project.clone(),
4230 project_context,
4231 context_server_registry,
4232 Templates::new(),
4233 model.clone(),
4234 subagent_context,
4235 std::collections::BTreeMap::new(),
4236 cx,
4237 );
4238 thread.add_tool(EchoTool);
4239 thread.add_tool(DelayTool);
4240 thread.add_tool(WordListTool);
4241 thread
4242 });
4243
4244 subagent.read_with(cx, |thread, _| {
4245 assert!(thread.has_registered_tool("echo"));
4246 assert!(thread.has_registered_tool("delay"));
4247 assert!(thread.has_registered_tool("word_list"));
4248 });
4249
4250 let allowed: collections::HashSet<gpui::SharedString> =
4251 vec!["echo".into()].into_iter().collect();
4252
4253 subagent.update(cx, |thread, _cx| {
4254 thread.restrict_tools(&allowed);
4255 });
4256
4257 subagent.read_with(cx, |thread, _| {
4258 assert!(
4259 thread.has_registered_tool("echo"),
4260 "echo should still be available"
4261 );
4262 assert!(
4263 !thread.has_registered_tool("delay"),
4264 "delay should be removed"
4265 );
4266 assert!(
4267 !thread.has_registered_tool("word_list"),
4268 "word_list should be removed"
4269 );
4270 });
4271}
4272
4273#[gpui::test]
4274async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4275 init_test(cx);
4276
4277 cx.update(|cx| {
4278 cx.update_flags(true, vec!["subagents".to_string()]);
4279 });
4280
4281 let fs = FakeFs::new(cx.executor());
4282 fs.insert_tree(path!("/test"), json!({})).await;
4283 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4284 let project_context = cx.new(|_cx| ProjectContext::default());
4285 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4286 let context_server_registry =
4287 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4288 let model = Arc::new(FakeLanguageModel::default());
4289
4290 let parent = cx.new(|cx| {
4291 Thread::new(
4292 project.clone(),
4293 project_context.clone(),
4294 context_server_registry.clone(),
4295 Templates::new(),
4296 Some(model.clone()),
4297 cx,
4298 )
4299 });
4300
4301 let subagent_context = SubagentContext {
4302 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4303 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4304 depth: 1,
4305 summary_prompt: "Summarize".to_string(),
4306 context_low_prompt: "Context low".to_string(),
4307 };
4308
4309 let subagent = cx.new(|cx| {
4310 Thread::new_subagent(
4311 project.clone(),
4312 project_context.clone(),
4313 context_server_registry.clone(),
4314 Templates::new(),
4315 model.clone(),
4316 subagent_context,
4317 std::collections::BTreeMap::new(),
4318 cx,
4319 )
4320 });
4321
4322 parent.update(cx, |thread, _cx| {
4323 thread.register_running_subagent(subagent.downgrade());
4324 });
4325
4326 subagent
4327 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4328 .unwrap();
4329 cx.run_until_parked();
4330
4331 subagent.read_with(cx, |thread, _| {
4332 assert!(!thread.is_turn_complete(), "subagent should be running");
4333 });
4334
4335 parent.update(cx, |thread, cx| {
4336 thread.cancel(cx).detach();
4337 });
4338
4339 subagent.read_with(cx, |thread, _| {
4340 assert!(
4341 thread.is_turn_complete(),
4342 "subagent should be cancelled when parent cancels"
4343 );
4344 });
4345}
4346
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    init_test(cx);
    always_allow_tools(cx);

    // Subagent support is gated behind a feature flag; enable it for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // Parent thread the SubagentTool will spawn its subagents under.
    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // No tools are forwarded to spawned subagents in this test.
    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
        std::collections::BTreeMap::new();

    // Depth 0: the tool is being run from the root thread.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(
        parent.downgrade(),
        project.clone(),
        project_context,
        context_server_registry,
        Templates::new(),
        0,
        parent_tools,
    ));

    // `test_with_cancellation` hands back a sender we can use to simulate the
    // user cancelling mid-run.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                subagents: vec![crate::SubagentConfig {
                    label: "Long running task".to_string(),
                    task_prompt: "Do a very long task that takes forever".to_string(),
                    summary_prompt: "Summarize".to_string(),
                    context_low_prompt: "Context low".to_string(),
                    timeout_ms: None,
                    allowed_tools: None,
                }],
            },
            event_stream.clone(),
            cx,
        )
    });

    // Let the subagent spin up before cancelling.
    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4435
4436#[gpui::test]
4437async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4438 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4439 let fake_model = model.as_fake();
4440
4441 cx.update(|cx| {
4442 cx.update_flags(true, vec!["subagents".to_string()]);
4443 });
4444
4445 let subagent_context = SubagentContext {
4446 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4447 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4448 depth: 1,
4449 summary_prompt: "Summarize".to_string(),
4450 context_low_prompt: "Context low".to_string(),
4451 };
4452
4453 let project = thread.read_with(cx, |t, _| t.project.clone());
4454 let project_context = cx.new(|_cx| ProjectContext::default());
4455 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4456 let context_server_registry =
4457 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4458
4459 let subagent = cx.new(|cx| {
4460 Thread::new_subagent(
4461 project.clone(),
4462 project_context,
4463 context_server_registry,
4464 Templates::new(),
4465 model.clone(),
4466 subagent_context,
4467 std::collections::BTreeMap::new(),
4468 cx,
4469 )
4470 });
4471
4472 subagent
4473 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4474 .unwrap();
4475 cx.run_until_parked();
4476
4477 subagent.read_with(cx, |thread, _| {
4478 assert!(!thread.is_turn_complete(), "turn should be in progress");
4479 });
4480
4481 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4482 provider: LanguageModelProviderName::from("Fake".to_string()),
4483 });
4484 fake_model.end_last_completion_stream();
4485 cx.run_until_parked();
4486
4487 subagent.read_with(cx, |thread, _| {
4488 assert!(
4489 thread.is_turn_complete(),
4490 "turn should be complete after non-retryable error"
4491 );
4492 });
4493}
4494
4495#[gpui::test]
4496async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4497 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4498 let fake_model = model.as_fake();
4499
4500 cx.update(|cx| {
4501 cx.update_flags(true, vec!["subagents".to_string()]);
4502 });
4503
4504 let subagent_context = SubagentContext {
4505 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4506 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4507 depth: 1,
4508 summary_prompt: "Summarize your work".to_string(),
4509 context_low_prompt: "Context low, stop and summarize".to_string(),
4510 };
4511
4512 let project = thread.read_with(cx, |t, _| t.project.clone());
4513 let project_context = cx.new(|_cx| ProjectContext::default());
4514 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4515 let context_server_registry =
4516 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4517
4518 let subagent = cx.new(|cx| {
4519 Thread::new_subagent(
4520 project.clone(),
4521 project_context.clone(),
4522 context_server_registry.clone(),
4523 Templates::new(),
4524 model.clone(),
4525 subagent_context.clone(),
4526 std::collections::BTreeMap::new(),
4527 cx,
4528 )
4529 });
4530
4531 subagent.update(cx, |thread, _| {
4532 thread.add_tool(EchoTool);
4533 });
4534
4535 subagent
4536 .update(cx, |thread, cx| {
4537 thread.submit_user_message("Do some work", cx)
4538 })
4539 .unwrap();
4540 cx.run_until_parked();
4541
4542 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4543 fake_model
4544 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4545 fake_model.end_last_completion_stream();
4546 cx.run_until_parked();
4547
4548 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4549 assert!(
4550 interrupt_result.is_ok(),
4551 "interrupt_for_summary should succeed"
4552 );
4553
4554 cx.run_until_parked();
4555
4556 let pending = fake_model.pending_completions();
4557 assert!(
4558 !pending.is_empty(),
4559 "should have pending completion for interrupted summary"
4560 );
4561
4562 let messages = &pending.last().unwrap().messages;
4563 let user_messages: Vec<_> = messages
4564 .iter()
4565 .filter(|m| m.role == language_model::Role::User)
4566 .collect();
4567
4568 let last_user = user_messages.last().unwrap();
4569 let content_str = last_user.content[0].to_str().unwrap();
4570 assert!(
4571 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4572 "context_low_prompt should be sent when interrupting: got {:?}",
4573 content_str
4574 );
4575}
4576
4577#[gpui::test]
4578async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4579 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4580 let fake_model = model.as_fake();
4581
4582 cx.update(|cx| {
4583 cx.update_flags(true, vec!["subagents".to_string()]);
4584 });
4585
4586 let subagent_context = SubagentContext {
4587 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4588 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4589 depth: 1,
4590 summary_prompt: "Summarize".to_string(),
4591 context_low_prompt: "Context low".to_string(),
4592 };
4593
4594 let project = thread.read_with(cx, |t, _| t.project.clone());
4595 let project_context = cx.new(|_cx| ProjectContext::default());
4596 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4597 let context_server_registry =
4598 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4599
4600 let subagent = cx.new(|cx| {
4601 Thread::new_subagent(
4602 project.clone(),
4603 project_context,
4604 context_server_registry,
4605 Templates::new(),
4606 model.clone(),
4607 subagent_context,
4608 std::collections::BTreeMap::new(),
4609 cx,
4610 )
4611 });
4612
4613 subagent
4614 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4615 .unwrap();
4616 cx.run_until_parked();
4617
4618 let max_tokens = model.max_token_count();
4619 let high_usage = language_model::TokenUsage {
4620 input_tokens: (max_tokens as f64 * 0.80) as u64,
4621 output_tokens: 0,
4622 cache_creation_input_tokens: 0,
4623 cache_read_input_tokens: 0,
4624 };
4625
4626 fake_model
4627 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4628 fake_model.send_last_completion_stream_text_chunk("Working...");
4629 fake_model
4630 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4631 fake_model.end_last_completion_stream();
4632 cx.run_until_parked();
4633
4634 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4635 assert!(usage.is_some(), "should have token usage after completion");
4636
4637 let usage = usage.unwrap();
4638 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4639 assert!(
4640 remaining_ratio <= 0.25,
4641 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4642 remaining_ratio * 100.0
4643 );
4644}
4645
4646#[gpui::test]
4647async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4648 init_test(cx);
4649
4650 cx.update(|cx| {
4651 cx.update_flags(true, vec!["subagents".to_string()]);
4652 });
4653
4654 let fs = FakeFs::new(cx.executor());
4655 fs.insert_tree(path!("/test"), json!({})).await;
4656 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4657 let project_context = cx.new(|_cx| ProjectContext::default());
4658 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4659 let context_server_registry =
4660 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4661 let model = Arc::new(FakeLanguageModel::default());
4662
4663 let parent = cx.new(|cx| {
4664 let mut thread = Thread::new(
4665 project.clone(),
4666 project_context.clone(),
4667 context_server_registry.clone(),
4668 Templates::new(),
4669 Some(model.clone()),
4670 cx,
4671 );
4672 thread.add_tool(EchoTool);
4673 thread
4674 });
4675
4676 let mut parent_tools: std::collections::BTreeMap<
4677 gpui::SharedString,
4678 Arc<dyn crate::AnyAgentTool>,
4679 > = std::collections::BTreeMap::new();
4680 parent_tools.insert("echo".into(), EchoTool.erase());
4681
4682 #[allow(clippy::arc_with_non_send_sync)]
4683 let tool = Arc::new(SubagentTool::new(
4684 parent.downgrade(),
4685 project,
4686 project_context,
4687 context_server_registry,
4688 Templates::new(),
4689 0,
4690 parent_tools,
4691 ));
4692
4693 let subagent_configs = vec![crate::SubagentConfig {
4694 label: "Test".to_string(),
4695 task_prompt: "Do something".to_string(),
4696 summary_prompt: "Summarize".to_string(),
4697 context_low_prompt: "Context low".to_string(),
4698 timeout_ms: None,
4699 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4700 }];
4701 let result = tool.validate_subagents(&subagent_configs);
4702 assert!(result.is_err(), "should reject unknown tool");
4703 let err_msg = result.unwrap_err().to_string();
4704 assert!(
4705 err_msg.contains("nonexistent_tool"),
4706 "error should mention the invalid tool name: {}",
4707 err_msg
4708 );
4709 assert!(
4710 err_msg.contains("do not exist"),
4711 "error should explain the tool does not exist: {}",
4712 err_msg
4713 );
4714}
4715
4716#[gpui::test]
4717async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4718 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4719 let fake_model = model.as_fake();
4720
4721 cx.update(|cx| {
4722 cx.update_flags(true, vec!["subagents".to_string()]);
4723 });
4724
4725 let subagent_context = SubagentContext {
4726 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4727 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4728 depth: 1,
4729 summary_prompt: "Summarize".to_string(),
4730 context_low_prompt: "Context low".to_string(),
4731 };
4732
4733 let project = thread.read_with(cx, |t, _| t.project.clone());
4734 let project_context = cx.new(|_cx| ProjectContext::default());
4735 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4736 let context_server_registry =
4737 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4738
4739 let subagent = cx.new(|cx| {
4740 Thread::new_subagent(
4741 project.clone(),
4742 project_context,
4743 context_server_registry,
4744 Templates::new(),
4745 model.clone(),
4746 subagent_context,
4747 std::collections::BTreeMap::new(),
4748 cx,
4749 )
4750 });
4751
4752 subagent
4753 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4754 .unwrap();
4755 cx.run_until_parked();
4756
4757 fake_model
4758 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4759 fake_model.end_last_completion_stream();
4760 cx.run_until_parked();
4761
4762 subagent.read_with(cx, |thread, _| {
4763 assert!(
4764 thread.is_turn_complete(),
4765 "turn should complete even with empty response"
4766 );
4767 });
4768}
4769
4770#[gpui::test]
4771async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4772 init_test(cx);
4773
4774 cx.update(|cx| {
4775 cx.update_flags(true, vec!["subagents".to_string()]);
4776 });
4777
4778 let fs = FakeFs::new(cx.executor());
4779 fs.insert_tree(path!("/test"), json!({})).await;
4780 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4781 let project_context = cx.new(|_cx| ProjectContext::default());
4782 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4783 let context_server_registry =
4784 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4785 let model = Arc::new(FakeLanguageModel::default());
4786
4787 let depth_1_context = SubagentContext {
4788 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4789 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4790 depth: 1,
4791 summary_prompt: "Summarize".to_string(),
4792 context_low_prompt: "Context low".to_string(),
4793 };
4794
4795 let depth_1_subagent = cx.new(|cx| {
4796 Thread::new_subagent(
4797 project.clone(),
4798 project_context.clone(),
4799 context_server_registry.clone(),
4800 Templates::new(),
4801 model.clone(),
4802 depth_1_context,
4803 std::collections::BTreeMap::new(),
4804 cx,
4805 )
4806 });
4807
4808 depth_1_subagent.read_with(cx, |thread, _| {
4809 assert_eq!(thread.depth(), 1);
4810 assert!(thread.is_subagent());
4811 });
4812
4813 let depth_2_context = SubagentContext {
4814 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4815 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4816 depth: 2,
4817 summary_prompt: "Summarize depth 2".to_string(),
4818 context_low_prompt: "Context low depth 2".to_string(),
4819 };
4820
4821 let depth_2_subagent = cx.new(|cx| {
4822 Thread::new_subagent(
4823 project.clone(),
4824 project_context.clone(),
4825 context_server_registry.clone(),
4826 Templates::new(),
4827 model.clone(),
4828 depth_2_context,
4829 std::collections::BTreeMap::new(),
4830 cx,
4831 )
4832 });
4833
4834 depth_2_subagent.read_with(cx, |thread, _| {
4835 assert_eq!(thread.depth(), 2);
4836 assert!(thread.is_subagent());
4837 });
4838
4839 depth_2_subagent
4840 .update(cx, |thread, cx| {
4841 thread.submit_user_message("Nested task", cx)
4842 })
4843 .unwrap();
4844 cx.run_until_parked();
4845
4846 let pending = model.as_fake().pending_completions();
4847 assert!(
4848 !pending.is_empty(),
4849 "depth-2 subagent should be able to submit messages"
4850 );
4851}
4852
4853#[gpui::test]
4854async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4855 init_test(cx);
4856 always_allow_tools(cx);
4857
4858 cx.update(|cx| {
4859 cx.update_flags(true, vec!["subagents".to_string()]);
4860 });
4861
4862 let fs = FakeFs::new(cx.executor());
4863 fs.insert_tree(path!("/test"), json!({})).await;
4864 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4865 let project_context = cx.new(|_cx| ProjectContext::default());
4866 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4867 let context_server_registry =
4868 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4869 let model = Arc::new(FakeLanguageModel::default());
4870 let fake_model = model.as_fake();
4871
4872 let subagent_context = SubagentContext {
4873 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4874 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4875 depth: 1,
4876 summary_prompt: "Summarize what you did".to_string(),
4877 context_low_prompt: "Context low".to_string(),
4878 };
4879
4880 let subagent = cx.new(|cx| {
4881 let mut thread = Thread::new_subagent(
4882 project.clone(),
4883 project_context,
4884 context_server_registry,
4885 Templates::new(),
4886 model.clone(),
4887 subagent_context,
4888 std::collections::BTreeMap::new(),
4889 cx,
4890 );
4891 thread.add_tool(EchoTool);
4892 thread
4893 });
4894
4895 subagent.read_with(cx, |thread, _| {
4896 assert!(
4897 thread.has_registered_tool("echo"),
4898 "subagent should have echo tool"
4899 );
4900 });
4901
4902 subagent
4903 .update(cx, |thread, cx| {
4904 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4905 })
4906 .unwrap();
4907 cx.run_until_parked();
4908
4909 let tool_use = LanguageModelToolUse {
4910 id: "tool_call_1".into(),
4911 name: EchoTool::name().into(),
4912 raw_input: json!({"text": "hello world"}).to_string(),
4913 input: json!({"text": "hello world"}),
4914 is_input_complete: true,
4915 thought_signature: None,
4916 };
4917 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4918 fake_model.end_last_completion_stream();
4919 cx.run_until_parked();
4920
4921 let pending = fake_model.pending_completions();
4922 assert!(
4923 !pending.is_empty(),
4924 "should have pending completion after tool use"
4925 );
4926
4927 let last_completion = pending.last().unwrap();
4928 let has_tool_result = last_completion.messages.iter().any(|m| {
4929 m.content
4930 .iter()
4931 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4932 });
4933 assert!(
4934 has_tool_result,
4935 "tool result should be in the messages sent back to the model"
4936 );
4937}
4938
4939#[gpui::test]
4940async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4941 init_test(cx);
4942
4943 cx.update(|cx| {
4944 cx.update_flags(true, vec!["subagents".to_string()]);
4945 });
4946
4947 let fs = FakeFs::new(cx.executor());
4948 fs.insert_tree(path!("/test"), json!({})).await;
4949 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4950 let project_context = cx.new(|_cx| ProjectContext::default());
4951 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4952 let context_server_registry =
4953 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4954 let model = Arc::new(FakeLanguageModel::default());
4955
4956 let parent = cx.new(|cx| {
4957 Thread::new(
4958 project.clone(),
4959 project_context.clone(),
4960 context_server_registry.clone(),
4961 Templates::new(),
4962 Some(model.clone()),
4963 cx,
4964 )
4965 });
4966
4967 let mut subagents = Vec::new();
4968 for i in 0..MAX_PARALLEL_SUBAGENTS {
4969 let subagent_context = SubagentContext {
4970 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4971 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4972 depth: 1,
4973 summary_prompt: "Summarize".to_string(),
4974 context_low_prompt: "Context low".to_string(),
4975 };
4976
4977 let subagent = cx.new(|cx| {
4978 Thread::new_subagent(
4979 project.clone(),
4980 project_context.clone(),
4981 context_server_registry.clone(),
4982 Templates::new(),
4983 model.clone(),
4984 subagent_context,
4985 std::collections::BTreeMap::new(),
4986 cx,
4987 )
4988 });
4989
4990 parent.update(cx, |thread, _cx| {
4991 thread.register_running_subagent(subagent.downgrade());
4992 });
4993 subagents.push(subagent);
4994 }
4995
4996 parent.read_with(cx, |thread, _| {
4997 assert_eq!(
4998 thread.running_subagent_count(),
4999 MAX_PARALLEL_SUBAGENTS,
5000 "should have MAX_PARALLEL_SUBAGENTS registered"
5001 );
5002 });
5003
5004 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
5005 std::collections::BTreeMap::new();
5006
5007 #[allow(clippy::arc_with_non_send_sync)]
5008 let tool = Arc::new(SubagentTool::new(
5009 parent.downgrade(),
5010 project.clone(),
5011 project_context,
5012 context_server_registry,
5013 Templates::new(),
5014 0,
5015 parent_tools,
5016 ));
5017
5018 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5019
5020 let result = cx.update(|cx| {
5021 tool.run(
5022 SubagentToolInput {
5023 subagents: vec![crate::SubagentConfig {
5024 label: "Test".to_string(),
5025 task_prompt: "Do something".to_string(),
5026 summary_prompt: "Summarize".to_string(),
5027 context_low_prompt: "Context low".to_string(),
5028 timeout_ms: None,
5029 allowed_tools: None,
5030 }],
5031 },
5032 event_stream,
5033 cx,
5034 )
5035 });
5036
5037 let err = result.await.unwrap_err();
5038 assert!(
5039 err.to_string().contains("Maximum parallel subagents"),
5040 "should reject when max parallel subagents reached: {}",
5041 err
5042 );
5043
5044 drop(subagents);
5045}
5046
#[gpui::test]
async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
    // Drives the SubagentTool through a full lifecycle against a fake model:
    // task prompt -> subagent response -> summary prompt -> summary returned
    // as the tool's result.
    init_test(cx);
    always_allow_tools(cx);

    // Subagent support is gated behind a feature flag; enable it for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let fake_model = model.as_fake();

    // Parent thread that exposes the echo tool.
    let parent = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        );
        thread.add_tool(EchoTool);
        thread
    });

    // Forward the echo tool to spawned subagents as well.
    let mut parent_tools: std::collections::BTreeMap<
        gpui::SharedString,
        Arc<dyn crate::AnyAgentTool>,
    > = std::collections::BTreeMap::new();
    parent_tools.insert("echo".into(), EchoTool.erase());

    // Depth 0: the tool runs from the root thread.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(
        parent.downgrade(),
        project.clone(),
        project_context,
        context_server_registry,
        Templates::new(),
        0,
        parent_tools,
    ));

    let (event_stream, _rx) = crate::ToolCallEventStream::test();

    // Launch a single subagent with distinct task/summary/context-low prompts
    // so each phase of the exchange is identifiable below.
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                subagents: vec![crate::SubagentConfig {
                    label: "Research task".to_string(),
                    task_prompt: "Find all TODOs in the codebase".to_string(),
                    summary_prompt: "Summarize what you found".to_string(),
                    context_low_prompt: "Context low, wrap up".to_string(),
                    timeout_ms: None,
                    allowed_tools: None,
                }],
            },
            event_stream,
            cx,
        )
    });

    cx.run_until_parked();

    // Phase 1: the subagent's first request to the model carries the task prompt.
    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "subagent should have started and sent a completion request"
    );

    let first_completion = &pending[0];
    let has_task_prompt = first_completion.messages.iter().any(|m| {
        m.role == language_model::Role::User
            && m.content
                .iter()
                .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
    });
    assert!(has_task_prompt, "task prompt should be sent to subagent");

    // Phase 2: the subagent finishes its task turn; the tool should then
    // follow up with the summary prompt.
    fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "should have pending completion for summary request"
    );

    let last_completion = pending.last().unwrap();
    let has_summary_prompt = last_completion.messages.iter().any(|m| {
        m.role == language_model::Role::User
            && m.content.iter().any(|c| {
                c.to_str()
                    .map(|s| s.contains("Summarize") || s.contains("summarize"))
                    .unwrap_or(false)
            })
    });
    assert!(
        has_summary_prompt,
        "summary prompt should be sent after task completion"
    );

    // Phase 3: stream the summary response and end the turn; the tool task
    // should resolve with that text.
    fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let result = task.await;
    assert!(result.is_ok(), "subagent tool should complete successfully");

    let summary = result.unwrap();
    assert!(
        summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
        "summary should contain subagent's response: {}",
        summary
    );
}
5174
5175#[gpui::test]
5176async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5177 init_test(cx);
5178
5179 let fs = FakeFs::new(cx.executor());
5180 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5181 .await;
5182 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5183
5184 cx.update(|cx| {
5185 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5186 settings.tool_permissions.tools.insert(
5187 "edit_file".into(),
5188 agent_settings::ToolRules {
5189 default_mode: settings::ToolPermissionMode::Allow,
5190 always_allow: vec![],
5191 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5192 always_confirm: vec![],
5193 invalid_patterns: vec![],
5194 },
5195 );
5196 agent_settings::AgentSettings::override_global(settings, cx);
5197 });
5198
5199 let context_server_registry =
5200 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5201 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5202 let templates = crate::Templates::new();
5203 let thread = cx.new(|cx| {
5204 crate::Thread::new(
5205 project.clone(),
5206 cx.new(|_cx| prompt_store::ProjectContext::default()),
5207 context_server_registry,
5208 templates.clone(),
5209 None,
5210 cx,
5211 )
5212 });
5213
5214 #[allow(clippy::arc_with_non_send_sync)]
5215 let tool = Arc::new(crate::EditFileTool::new(
5216 project.clone(),
5217 thread.downgrade(),
5218 language_registry,
5219 templates,
5220 ));
5221 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5222
5223 let task = cx.update(|cx| {
5224 tool.run(
5225 crate::EditFileToolInput {
5226 display_description: "Edit sensitive file".to_string(),
5227 path: "root/sensitive_config.txt".into(),
5228 mode: crate::EditFileMode::Edit,
5229 },
5230 event_stream,
5231 cx,
5232 )
5233 });
5234
5235 let result = task.await;
5236 assert!(result.is_err(), "expected edit to be blocked");
5237 assert!(
5238 result.unwrap_err().to_string().contains("blocked"),
5239 "error should mention the edit was blocked"
5240 );
5241}
5242
5243#[gpui::test]
5244async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5245 init_test(cx);
5246
5247 let fs = FakeFs::new(cx.executor());
5248 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5249 .await;
5250 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5251
5252 cx.update(|cx| {
5253 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5254 settings.tool_permissions.tools.insert(
5255 "delete_path".into(),
5256 agent_settings::ToolRules {
5257 default_mode: settings::ToolPermissionMode::Allow,
5258 always_allow: vec![],
5259 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5260 always_confirm: vec![],
5261 invalid_patterns: vec![],
5262 },
5263 );
5264 agent_settings::AgentSettings::override_global(settings, cx);
5265 });
5266
5267 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5268
5269 #[allow(clippy::arc_with_non_send_sync)]
5270 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5271 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5272
5273 let task = cx.update(|cx| {
5274 tool.run(
5275 crate::DeletePathToolInput {
5276 path: "root/important_data.txt".to_string(),
5277 },
5278 event_stream,
5279 cx,
5280 )
5281 });
5282
5283 let result = task.await;
5284 assert!(result.is_err(), "expected deletion to be blocked");
5285 assert!(
5286 result.unwrap_err().to_string().contains("blocked"),
5287 "error should mention the deletion was blocked"
5288 );
5289}
5290
5291#[gpui::test]
5292async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5293 init_test(cx);
5294
5295 let fs = FakeFs::new(cx.executor());
5296 fs.insert_tree(
5297 "/root",
5298 json!({
5299 "safe.txt": "content",
5300 "protected": {}
5301 }),
5302 )
5303 .await;
5304 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5305
5306 cx.update(|cx| {
5307 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5308 settings.tool_permissions.tools.insert(
5309 "move_path".into(),
5310 agent_settings::ToolRules {
5311 default_mode: settings::ToolPermissionMode::Allow,
5312 always_allow: vec![],
5313 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5314 always_confirm: vec![],
5315 invalid_patterns: vec![],
5316 },
5317 );
5318 agent_settings::AgentSettings::override_global(settings, cx);
5319 });
5320
5321 #[allow(clippy::arc_with_non_send_sync)]
5322 let tool = Arc::new(crate::MovePathTool::new(project));
5323 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5324
5325 let task = cx.update(|cx| {
5326 tool.run(
5327 crate::MovePathToolInput {
5328 source_path: "root/safe.txt".to_string(),
5329 destination_path: "root/protected/safe.txt".to_string(),
5330 },
5331 event_stream,
5332 cx,
5333 )
5334 });
5335
5336 let result = task.await;
5337 assert!(
5338 result.is_err(),
5339 "expected move to be blocked due to destination path"
5340 );
5341 assert!(
5342 result.unwrap_err().to_string().contains("blocked"),
5343 "error should mention the move was blocked"
5344 );
5345}
5346
5347#[gpui::test]
5348async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5349 init_test(cx);
5350
5351 let fs = FakeFs::new(cx.executor());
5352 fs.insert_tree(
5353 "/root",
5354 json!({
5355 "secret.txt": "secret content",
5356 "public": {}
5357 }),
5358 )
5359 .await;
5360 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5361
5362 cx.update(|cx| {
5363 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5364 settings.tool_permissions.tools.insert(
5365 "move_path".into(),
5366 agent_settings::ToolRules {
5367 default_mode: settings::ToolPermissionMode::Allow,
5368 always_allow: vec![],
5369 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5370 always_confirm: vec![],
5371 invalid_patterns: vec![],
5372 },
5373 );
5374 agent_settings::AgentSettings::override_global(settings, cx);
5375 });
5376
5377 #[allow(clippy::arc_with_non_send_sync)]
5378 let tool = Arc::new(crate::MovePathTool::new(project));
5379 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5380
5381 let task = cx.update(|cx| {
5382 tool.run(
5383 crate::MovePathToolInput {
5384 source_path: "root/secret.txt".to_string(),
5385 destination_path: "root/public/not_secret.txt".to_string(),
5386 },
5387 event_stream,
5388 cx,
5389 )
5390 });
5391
5392 let result = task.await;
5393 assert!(
5394 result.is_err(),
5395 "expected move to be blocked due to source path"
5396 );
5397 assert!(
5398 result.unwrap_err().to_string().contains("blocked"),
5399 "error should mention the move was blocked"
5400 );
5401}
5402
5403#[gpui::test]
5404async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5405 init_test(cx);
5406
5407 let fs = FakeFs::new(cx.executor());
5408 fs.insert_tree(
5409 "/root",
5410 json!({
5411 "confidential.txt": "confidential data",
5412 "dest": {}
5413 }),
5414 )
5415 .await;
5416 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5417
5418 cx.update(|cx| {
5419 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5420 settings.tool_permissions.tools.insert(
5421 "copy_path".into(),
5422 agent_settings::ToolRules {
5423 default_mode: settings::ToolPermissionMode::Allow,
5424 always_allow: vec![],
5425 always_deny: vec![
5426 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5427 ],
5428 always_confirm: vec![],
5429 invalid_patterns: vec![],
5430 },
5431 );
5432 agent_settings::AgentSettings::override_global(settings, cx);
5433 });
5434
5435 #[allow(clippy::arc_with_non_send_sync)]
5436 let tool = Arc::new(crate::CopyPathTool::new(project));
5437 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5438
5439 let task = cx.update(|cx| {
5440 tool.run(
5441 crate::CopyPathToolInput {
5442 source_path: "root/confidential.txt".to_string(),
5443 destination_path: "root/dest/copy.txt".to_string(),
5444 },
5445 event_stream,
5446 cx,
5447 )
5448 });
5449
5450 let result = task.await;
5451 assert!(result.is_err(), "expected copy to be blocked");
5452 assert!(
5453 result.unwrap_err().to_string().contains("blocked"),
5454 "error should mention the copy was blocked"
5455 );
5456}
5457
5458#[gpui::test]
5459async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5460 init_test(cx);
5461
5462 let fs = FakeFs::new(cx.executor());
5463 fs.insert_tree(
5464 "/root",
5465 json!({
5466 "normal.txt": "normal content",
5467 "readonly": {
5468 "config.txt": "readonly content"
5469 }
5470 }),
5471 )
5472 .await;
5473 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5474
5475 cx.update(|cx| {
5476 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5477 settings.tool_permissions.tools.insert(
5478 "save_file".into(),
5479 agent_settings::ToolRules {
5480 default_mode: settings::ToolPermissionMode::Allow,
5481 always_allow: vec![],
5482 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5483 always_confirm: vec![],
5484 invalid_patterns: vec![],
5485 },
5486 );
5487 agent_settings::AgentSettings::override_global(settings, cx);
5488 });
5489
5490 #[allow(clippy::arc_with_non_send_sync)]
5491 let tool = Arc::new(crate::SaveFileTool::new(project));
5492 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5493
5494 let task = cx.update(|cx| {
5495 tool.run(
5496 crate::SaveFileToolInput {
5497 paths: vec![
5498 std::path::PathBuf::from("root/normal.txt"),
5499 std::path::PathBuf::from("root/readonly/config.txt"),
5500 ],
5501 },
5502 event_stream,
5503 cx,
5504 )
5505 });
5506
5507 let result = task.await;
5508 assert!(
5509 result.is_err(),
5510 "expected save to be blocked due to denied path"
5511 );
5512 assert!(
5513 result.unwrap_err().to_string().contains("blocked"),
5514 "error should mention the save was blocked"
5515 );
5516}
5517
5518#[gpui::test]
5519async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5520 init_test(cx);
5521
5522 let fs = FakeFs::new(cx.executor());
5523 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5524 .await;
5525 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5526
5527 cx.update(|cx| {
5528 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5529 settings.always_allow_tool_actions = false;
5530 settings.tool_permissions.tools.insert(
5531 "save_file".into(),
5532 agent_settings::ToolRules {
5533 default_mode: settings::ToolPermissionMode::Allow,
5534 always_allow: vec![],
5535 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5536 always_confirm: vec![],
5537 invalid_patterns: vec![],
5538 },
5539 );
5540 agent_settings::AgentSettings::override_global(settings, cx);
5541 });
5542
5543 #[allow(clippy::arc_with_non_send_sync)]
5544 let tool = Arc::new(crate::SaveFileTool::new(project));
5545 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5546
5547 let task = cx.update(|cx| {
5548 tool.run(
5549 crate::SaveFileToolInput {
5550 paths: vec![std::path::PathBuf::from("root/config.secret")],
5551 },
5552 event_stream,
5553 cx,
5554 )
5555 });
5556
5557 let result = task.await;
5558 assert!(result.is_err(), "expected save to be blocked");
5559 assert!(
5560 result.unwrap_err().to_string().contains("blocked"),
5561 "error should mention the save was blocked"
5562 );
5563}
5564
5565#[gpui::test]
5566async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5567 init_test(cx);
5568
5569 cx.update(|cx| {
5570 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5571 settings.tool_permissions.tools.insert(
5572 "web_search".into(),
5573 agent_settings::ToolRules {
5574 default_mode: settings::ToolPermissionMode::Allow,
5575 always_allow: vec![],
5576 always_deny: vec![
5577 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5578 ],
5579 always_confirm: vec![],
5580 invalid_patterns: vec![],
5581 },
5582 );
5583 agent_settings::AgentSettings::override_global(settings, cx);
5584 });
5585
5586 #[allow(clippy::arc_with_non_send_sync)]
5587 let tool = Arc::new(crate::WebSearchTool);
5588 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5589
5590 let input: crate::WebSearchToolInput =
5591 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5592
5593 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5594
5595 let result = task.await;
5596 assert!(result.is_err(), "expected search to be blocked");
5597 assert!(
5598 result.unwrap_err().to_string().contains("blocked"),
5599 "error should mention the search was blocked"
5600 );
5601}
5602
5603#[gpui::test]
5604async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5605 init_test(cx);
5606
5607 let fs = FakeFs::new(cx.executor());
5608 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5609 .await;
5610 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5611
5612 cx.update(|cx| {
5613 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5614 settings.always_allow_tool_actions = false;
5615 settings.tool_permissions.tools.insert(
5616 "edit_file".into(),
5617 agent_settings::ToolRules {
5618 default_mode: settings::ToolPermissionMode::Confirm,
5619 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5620 always_deny: vec![],
5621 always_confirm: vec![],
5622 invalid_patterns: vec![],
5623 },
5624 );
5625 agent_settings::AgentSettings::override_global(settings, cx);
5626 });
5627
5628 let context_server_registry =
5629 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5630 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5631 let templates = crate::Templates::new();
5632 let thread = cx.new(|cx| {
5633 crate::Thread::new(
5634 project.clone(),
5635 cx.new(|_cx| prompt_store::ProjectContext::default()),
5636 context_server_registry,
5637 templates.clone(),
5638 None,
5639 cx,
5640 )
5641 });
5642
5643 #[allow(clippy::arc_with_non_send_sync)]
5644 let tool = Arc::new(crate::EditFileTool::new(
5645 project,
5646 thread.downgrade(),
5647 language_registry,
5648 templates,
5649 ));
5650 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5651
5652 let _task = cx.update(|cx| {
5653 tool.run(
5654 crate::EditFileToolInput {
5655 display_description: "Edit README".to_string(),
5656 path: "root/README.md".into(),
5657 mode: crate::EditFileMode::Edit,
5658 },
5659 event_stream,
5660 cx,
5661 )
5662 });
5663
5664 cx.run_until_parked();
5665
5666 let event = rx.try_next();
5667 assert!(
5668 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5669 "expected no authorization request for allowed .md file"
5670 );
5671}
5672
5673#[gpui::test]
5674async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5675 init_test(cx);
5676
5677 cx.update(|cx| {
5678 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5679 settings.tool_permissions.tools.insert(
5680 "fetch".into(),
5681 agent_settings::ToolRules {
5682 default_mode: settings::ToolPermissionMode::Allow,
5683 always_allow: vec![],
5684 always_deny: vec![
5685 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5686 ],
5687 always_confirm: vec![],
5688 invalid_patterns: vec![],
5689 },
5690 );
5691 agent_settings::AgentSettings::override_global(settings, cx);
5692 });
5693
5694 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5695
5696 #[allow(clippy::arc_with_non_send_sync)]
5697 let tool = Arc::new(crate::FetchTool::new(http_client));
5698 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5699
5700 let input: crate::FetchToolInput =
5701 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5702
5703 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5704
5705 let result = task.await;
5706 assert!(result.is_err(), "expected fetch to be blocked");
5707 assert!(
5708 result.unwrap_err().to_string().contains("blocked"),
5709 "error should mention the fetch was blocked"
5710 );
5711}
5712
5713#[gpui::test]
5714async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5715 init_test(cx);
5716
5717 cx.update(|cx| {
5718 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5719 settings.always_allow_tool_actions = false;
5720 settings.tool_permissions.tools.insert(
5721 "fetch".into(),
5722 agent_settings::ToolRules {
5723 default_mode: settings::ToolPermissionMode::Confirm,
5724 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5725 always_deny: vec![],
5726 always_confirm: vec![],
5727 invalid_patterns: vec![],
5728 },
5729 );
5730 agent_settings::AgentSettings::override_global(settings, cx);
5731 });
5732
5733 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5734
5735 #[allow(clippy::arc_with_non_send_sync)]
5736 let tool = Arc::new(crate::FetchTool::new(http_client));
5737 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5738
5739 let input: crate::FetchToolInput =
5740 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5741
5742 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5743
5744 cx.run_until_parked();
5745
5746 let event = rx.try_next();
5747 assert!(
5748 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5749 "expected no authorization request for allowed docs.rs URL"
5750 );
5751}