1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use feature_flags::FeatureFlagAppExt as _;
11use fs::{FakeFs, Fs};
12use futures::{
13 FutureExt as _, StreamExt,
14 channel::{
15 mpsc::{self, UnboundedReceiver},
16 oneshot,
17 },
18 future::{Fuse, Shared},
19};
20use gpui::{
21 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
22 http_client::FakeHttpClient,
23};
24use indoc::indoc;
25use language_model::{
26 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
27 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
29 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
30};
31use pretty_assertions::assert_eq;
32use project::{
33 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
34};
35use prompt_store::ProjectContext;
36use reqwest_client::ReqwestClient;
37use schemars::JsonSchema;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use settings::{Settings, SettingsStore};
41use std::{
42 path::Path,
43 pin::Pin,
44 rc::Rc,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, Ordering},
48 },
49 time::Duration,
50};
51use util::path;
52
53mod test_tools;
54use test_tools::*;
55
56fn init_test(cx: &mut TestAppContext) {
57 cx.update(|cx| {
58 let settings_store = SettingsStore::test(cx);
59 cx.set_global(settings_store);
60 });
61}
62
63struct FakeTerminalHandle {
64 killed: Arc<AtomicBool>,
65 stopped_by_user: Arc<AtomicBool>,
66 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
67 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
68 output: acp::TerminalOutputResponse,
69 id: acp::TerminalId,
70}
71
72impl FakeTerminalHandle {
73 fn new_never_exits(cx: &mut App) -> Self {
74 let killed = Arc::new(AtomicBool::new(false));
75 let stopped_by_user = Arc::new(AtomicBool::new(false));
76
77 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
78
79 let wait_for_exit = cx
80 .spawn(async move |_cx| {
81 // Wait for the exit signal (sent when kill() is called)
82 let _ = exit_receiver.await;
83 acp::TerminalExitStatus::new()
84 })
85 .shared();
86
87 Self {
88 killed,
89 stopped_by_user,
90 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
91 wait_for_exit,
92 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
93 id: acp::TerminalId::new("fake_terminal".to_string()),
94 }
95 }
96
97 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
98 let killed = Arc::new(AtomicBool::new(false));
99 let stopped_by_user = Arc::new(AtomicBool::new(false));
100 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
101
102 let wait_for_exit = cx
103 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
104 .shared();
105
106 Self {
107 killed,
108 stopped_by_user,
109 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
110 wait_for_exit,
111 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
112 id: acp::TerminalId::new("fake_terminal".to_string()),
113 }
114 }
115
116 fn was_killed(&self) -> bool {
117 self.killed.load(Ordering::SeqCst)
118 }
119
120 fn set_stopped_by_user(&self, stopped: bool) {
121 self.stopped_by_user.store(stopped, Ordering::SeqCst);
122 }
123
124 fn signal_exit(&self) {
125 if let Some(sender) = self.exit_sender.borrow_mut().take() {
126 let _ = sender.send(());
127 }
128 }
129}
130
131impl crate::TerminalHandle for FakeTerminalHandle {
132 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
133 Ok(self.id.clone())
134 }
135
136 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
137 Ok(self.output.clone())
138 }
139
140 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
141 Ok(self.wait_for_exit.clone())
142 }
143
144 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
145 self.killed.store(true, Ordering::SeqCst);
146 self.signal_exit();
147 Ok(())
148 }
149
150 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
151 Ok(self.stopped_by_user.load(Ordering::SeqCst))
152 }
153}
154
/// Thread environment that hands out the same pre-built fake terminal handle
/// for every `create_terminal` call.
struct FakeThreadEnvironment {
    /// The handle returned (cloned) from `create_terminal`.
    handle: Rc<FakeTerminalHandle>,
}
158
159impl crate::ThreadEnvironment for FakeThreadEnvironment {
160 fn create_terminal(
161 &self,
162 _command: String,
163 _cwd: Option<std::path::PathBuf>,
164 _output_byte_limit: Option<u64>,
165 _cx: &mut AsyncApp,
166 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
167 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
168 }
169}
170
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created so far, in creation order (appended by `create_terminal`).
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
175
176impl MultiTerminalEnvironment {
177 fn new() -> Self {
178 Self {
179 handles: std::cell::RefCell::new(Vec::new()),
180 }
181 }
182
183 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
184 self.handles.borrow().clone()
185 }
186}
187
188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
189 fn create_terminal(
190 &self,
191 _command: String,
192 _cwd: Option<std::path::PathBuf>,
193 _output_byte_limit: Option<u64>,
194 cx: &mut AsyncApp,
195 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
196 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
197 self.handles.borrow_mut().push(handle.clone());
198 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
199 }
200}
201
202fn always_allow_tools(cx: &mut TestAppContext) {
203 cx.update(|cx| {
204 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
205 settings.always_allow_tool_actions = true;
206 agent_settings::AgentSettings::override_global(settings, cx);
207 });
208}
209
210#[gpui::test]
211async fn test_echo(cx: &mut TestAppContext) {
212 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
213 let fake_model = model.as_fake();
214
215 let events = thread
216 .update(cx, |thread, cx| {
217 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
218 })
219 .unwrap();
220 cx.run_until_parked();
221 fake_model.send_last_completion_stream_text_chunk("Hello");
222 fake_model
223 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
224 fake_model.end_last_completion_stream();
225
226 let events = events.collect().await;
227 thread.update(cx, |thread, _cx| {
228 assert_eq!(
229 thread.last_message().unwrap().to_markdown(),
230 indoc! {"
231 ## Assistant
232
233 Hello
234 "}
235 )
236 });
237 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
238}
239
// Runs the terminal tool with a 5ms timeout against a fake terminal that never
// exits on its own, and checks that the handle gets killed and that the tool
// result still contains the terminal's output.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // The fake's exit task only completes once `kill` signals it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update is expected to carry the terminal content block.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fuse the task so it can be polled repeatedly with `now_or_never`.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll until the task finishes, giving up after a 500ms wall-clock deadline.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let pending work run, then wait 1ms on the executor's timer before polling again.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
305
// Runs the terminal tool with no timeout against a fake terminal that never
// exits, waits 25ms on the executor's timer, and checks the handle was not killed.
// NOTE(review): this test is `#[ignore]`d with no stated reason — confirm why.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // The task is held in `_task` for the rest of the test but never awaited.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    // The first update is expected to carry the terminal content block.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool some time in which it could (wrongly) kill the terminal.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
355
// Streams a thinking event followed by text through the fake model and checks
// that the assistant message renders the thinking inside `<think>` tags.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Thinking first, then the visible answer, then the end of the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
399
400#[gpui::test]
401async fn test_system_prompt(cx: &mut TestAppContext) {
402 let ThreadTest {
403 model,
404 thread,
405 project_context,
406 ..
407 } = setup(cx, TestModel::Fake).await;
408 let fake_model = model.as_fake();
409
410 project_context.update(cx, |project_context, _cx| {
411 project_context.shell = "test-shell".into()
412 });
413 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
414 thread
415 .update(cx, |thread, cx| {
416 thread.send(UserMessageId::new(), ["abc"], cx)
417 })
418 .unwrap();
419 cx.run_until_parked();
420 let mut pending_completions = fake_model.pending_completions();
421 assert_eq!(
422 pending_completions.len(),
423 1,
424 "unexpected pending completions: {:?}",
425 pending_completions
426 );
427
428 let pending_completion = pending_completions.pop().unwrap();
429 assert_eq!(pending_completion.messages[0].role, Role::System);
430
431 let system_message = &pending_completion.messages[0];
432 let system_prompt = system_message.content[0].to_str().unwrap();
433 assert!(
434 system_prompt.contains("test-shell"),
435 "unexpected system message: {:?}",
436 system_message
437 );
438 assert!(
439 system_prompt.contains("## Fixing Diagnostics"),
440 "unexpected system message: {:?}",
441 system_message
442 );
443}
444
445#[gpui::test]
446async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
448 let fake_model = model.as_fake();
449
450 thread
451 .update(cx, |thread, cx| {
452 thread.send(UserMessageId::new(), ["abc"], cx)
453 })
454 .unwrap();
455 cx.run_until_parked();
456 let mut pending_completions = fake_model.pending_completions();
457 assert_eq!(
458 pending_completions.len(),
459 1,
460 "unexpected pending completions: {:?}",
461 pending_completions
462 );
463
464 let pending_completion = pending_completions.pop().unwrap();
465 assert_eq!(pending_completion.messages[0].role, Role::System);
466
467 let system_message = &pending_completion.messages[0];
468 let system_prompt = system_message.content[0].to_str().unwrap();
469 assert!(
470 !system_prompt.contains("## Tool Use"),
471 "unexpected system message: {:?}",
472 system_message
473 );
474 assert!(
475 !system_prompt.contains("## Fixing Diagnostics"),
476 "unexpected system message: {:?}",
477 system_message
478 );
479}
480
// Checks that, in every request sent to the model, only the final message has
// `cache: true`. Covers plain user messages and a tool result.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // `messages[0]` is skipped in every comparison below; only what follows it is checked.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request should include the tool use and its result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
626
// Exercises the echo and delay tools in two consecutive turns and checks that
// each turn ends with `EndTurn`. Uses `TestModel::Sonnet4`; ignored unless the
// "e2e" feature is enabled.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The final agent message must contain a text part with "Ding"
    // (presumably the delay tool's output — confirm against `DelayTool`).
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
686
// Checks that tool input is observable while it is still streaming: at least
// one `ToolCall` event must arrive while the tool use in the thread is incomplete.
// Uses `TestModel::Sonnet4`; ignored unless the "e2e" feature is enabled.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Partial input: not yet complete and the "g" key has not arrived.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once the call is no longer pending, both keys must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
738
// Walks through the permission flow for a tool that requires authorization:
// allow once, deny, always allow, and finally a call that needs no prompt.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two tool uses in the same completion, each expected to request authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The next request's last message carries both results: one success, one error.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // No authorization response is sent here; the result must show up anyway.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
877
878#[gpui::test]
879async fn test_tool_hallucination(cx: &mut TestAppContext) {
880 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
881 let fake_model = model.as_fake();
882
883 let mut events = thread
884 .update(cx, |thread, cx| {
885 thread.send(UserMessageId::new(), ["abc"], cx)
886 })
887 .unwrap();
888 cx.run_until_parked();
889 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
890 LanguageModelToolUse {
891 id: "tool_id_1".into(),
892 name: "nonexistent_tool".into(),
893 raw_input: "{}".into(),
894 input: json!({}),
895 is_input_complete: true,
896 thought_signature: None,
897 },
898 ));
899 fake_model.end_last_completion_stream();
900
901 let tool_call = expect_tool_call(&mut events).await;
902 assert_eq!(tool_call.title, "nonexistent_tool");
903 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
904 let update = expect_tool_call_update_fields(&mut events).await;
905 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
906}
907
// After the model reports `ToolUseLimitReached`, the turn ends with a
// `ToolUseLimitReachedError`. Calling `resume` then sends the prior history
// plus a "Continue where you left off" user message.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up request carries the tool use and its result; the result is the cached message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message, which becomes the cached one.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
1022
// After the turn ends with a `ToolUseLimitReachedError`, sending a new user
// message (instead of resuming) must produce a request holding the prior tool
// use and result followed by the new message, with only the new message cached.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The tool use and the limit notification arrive in the same completion stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1099
1100async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1101 let event = events
1102 .next()
1103 .await
1104 .expect("no tool call authorization event received")
1105 .unwrap();
1106 match event {
1107 ThreadEvent::ToolCall(tool_call) => tool_call,
1108 event => {
1109 panic!("Unexpected event {event:?}");
1110 }
1111 }
1112}
1113
1114async fn expect_tool_call_update_fields(
1115 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1116) -> acp::ToolCallUpdate {
1117 let event = events
1118 .next()
1119 .await
1120 .expect("no tool call authorization event received")
1121 .unwrap();
1122 match event {
1123 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1124 event => {
1125 panic!("Unexpected event {event:?}");
1126 }
1127 }
1128}
1129
1130async fn next_tool_call_authorization(
1131 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1132) -> ToolCallAuthorization {
1133 loop {
1134 let event = events
1135 .next()
1136 .await
1137 .expect("no tool call authorization event received")
1138 .unwrap();
1139 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1140 let permission_kinds = tool_call_authorization
1141 .options
1142 .iter()
1143 .map(|o| o.kind)
1144 .collect::<Vec<_>>();
1145 // Only 2 options now: AllowAlways (for tool) and AllowOnce (granularity only)
1146 // Deny is handled by the UI buttons, not as a separate option
1147 assert_eq!(
1148 permission_kinds,
1149 vec![
1150 acp::PermissionOptionKind::AllowAlways,
1151 acp::PermissionOptionKind::AllowOnce,
1152 ]
1153 );
1154 return tool_call_authorization;
1155 }
1156 }
1157}
1158
1159#[gpui::test]
1160#[cfg_attr(not(feature = "e2e"), ignore)]
1161async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1162 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1163
1164 // Test concurrent tool calls with different delay times
1165 let events = thread
1166 .update(cx, |thread, cx| {
1167 thread.add_tool(DelayTool);
1168 thread.send(
1169 UserMessageId::new(),
1170 [
1171 "Call the delay tool twice in the same message.",
1172 "Once with 100ms. Once with 300ms.",
1173 "When both timers are complete, describe the outputs.",
1174 ],
1175 cx,
1176 )
1177 })
1178 .unwrap()
1179 .collect()
1180 .await;
1181
1182 let stop_reasons = stop_events(events);
1183 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1184
1185 thread.update(cx, |thread, _cx| {
1186 let last_message = thread.last_message().unwrap();
1187 let agent_message = last_message.as_agent_message().unwrap();
1188 let text = agent_message
1189 .content
1190 .iter()
1191 .filter_map(|content| {
1192 if let AgentMessageContent::Text(text) = content {
1193 Some(text.as_str())
1194 } else {
1195 None
1196 }
1197 })
1198 .collect::<String>();
1199
1200 assert!(text.contains("Ding"));
1201 });
1202}
1203
1204#[gpui::test]
1205async fn test_profiles(cx: &mut TestAppContext) {
1206 let ThreadTest {
1207 model, thread, fs, ..
1208 } = setup(cx, TestModel::Fake).await;
1209 let fake_model = model.as_fake();
1210
1211 thread.update(cx, |thread, _cx| {
1212 thread.add_tool(DelayTool);
1213 thread.add_tool(EchoTool);
1214 thread.add_tool(InfiniteTool);
1215 });
1216
1217 // Override profiles and wait for settings to be loaded.
1218 fs.insert_file(
1219 paths::settings_file(),
1220 json!({
1221 "agent": {
1222 "profiles": {
1223 "test-1": {
1224 "name": "Test Profile 1",
1225 "tools": {
1226 EchoTool::name(): true,
1227 DelayTool::name(): true,
1228 }
1229 },
1230 "test-2": {
1231 "name": "Test Profile 2",
1232 "tools": {
1233 InfiniteTool::name(): true,
1234 }
1235 }
1236 }
1237 }
1238 })
1239 .to_string()
1240 .into_bytes(),
1241 )
1242 .await;
1243 cx.run_until_parked();
1244
1245 // Test that test-1 profile (default) has echo and delay tools
1246 thread
1247 .update(cx, |thread, cx| {
1248 thread.set_profile(AgentProfileId("test-1".into()), cx);
1249 thread.send(UserMessageId::new(), ["test"], cx)
1250 })
1251 .unwrap();
1252 cx.run_until_parked();
1253
1254 let mut pending_completions = fake_model.pending_completions();
1255 assert_eq!(pending_completions.len(), 1);
1256 let completion = pending_completions.pop().unwrap();
1257 let tool_names: Vec<String> = completion
1258 .tools
1259 .iter()
1260 .map(|tool| tool.name.clone())
1261 .collect();
1262 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1263 fake_model.end_last_completion_stream();
1264
1265 // Switch to test-2 profile, and verify that it has only the infinite tool.
1266 thread
1267 .update(cx, |thread, cx| {
1268 thread.set_profile(AgentProfileId("test-2".into()), cx);
1269 thread.send(UserMessageId::new(), ["test2"], cx)
1270 })
1271 .unwrap();
1272 cx.run_until_parked();
1273 let mut pending_completions = fake_model.pending_completions();
1274 assert_eq!(pending_completions.len(), 1);
1275 let completion = pending_completions.pop().unwrap();
1276 let tool_names: Vec<String> = completion
1277 .tools
1278 .iter()
1279 .map(|tool| tool.name.clone())
1280 .collect();
1281 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1282}
1283
1284#[gpui::test]
1285async fn test_mcp_tools(cx: &mut TestAppContext) {
1286 let ThreadTest {
1287 model,
1288 thread,
1289 context_server_store,
1290 fs,
1291 ..
1292 } = setup(cx, TestModel::Fake).await;
1293 let fake_model = model.as_fake();
1294
1295 // Override profiles and wait for settings to be loaded.
1296 fs.insert_file(
1297 paths::settings_file(),
1298 json!({
1299 "agent": {
1300 "always_allow_tool_actions": true,
1301 "profiles": {
1302 "test": {
1303 "name": "Test Profile",
1304 "enable_all_context_servers": true,
1305 "tools": {
1306 EchoTool::name(): true,
1307 }
1308 },
1309 }
1310 }
1311 })
1312 .to_string()
1313 .into_bytes(),
1314 )
1315 .await;
1316 cx.run_until_parked();
1317 thread.update(cx, |thread, cx| {
1318 thread.set_profile(AgentProfileId("test".into()), cx)
1319 });
1320
1321 let mut mcp_tool_calls = setup_context_server(
1322 "test_server",
1323 vec![context_server::types::Tool {
1324 name: "echo".into(),
1325 description: None,
1326 input_schema: serde_json::to_value(EchoTool::input_schema(
1327 LanguageModelToolSchemaFormat::JsonSchema,
1328 ))
1329 .unwrap(),
1330 output_schema: None,
1331 annotations: None,
1332 }],
1333 &context_server_store,
1334 cx,
1335 );
1336
1337 let events = thread.update(cx, |thread, cx| {
1338 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1339 });
1340 cx.run_until_parked();
1341
1342 // Simulate the model calling the MCP tool.
1343 let completion = fake_model.pending_completions().pop().unwrap();
1344 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1345 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1346 LanguageModelToolUse {
1347 id: "tool_1".into(),
1348 name: "echo".into(),
1349 raw_input: json!({"text": "test"}).to_string(),
1350 input: json!({"text": "test"}),
1351 is_input_complete: true,
1352 thought_signature: None,
1353 },
1354 ));
1355 fake_model.end_last_completion_stream();
1356 cx.run_until_parked();
1357
1358 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1359 assert_eq!(tool_call_params.name, "echo");
1360 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1361 tool_call_response
1362 .send(context_server::types::CallToolResponse {
1363 content: vec![context_server::types::ToolResponseContent::Text {
1364 text: "test".into(),
1365 }],
1366 is_error: None,
1367 meta: None,
1368 structured_content: None,
1369 })
1370 .unwrap();
1371 cx.run_until_parked();
1372
1373 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1374 fake_model.send_last_completion_stream_text_chunk("Done!");
1375 fake_model.end_last_completion_stream();
1376 events.collect::<Vec<_>>().await;
1377
1378 // Send again after adding the echo tool, ensuring the name collision is resolved.
1379 let events = thread.update(cx, |thread, cx| {
1380 thread.add_tool(EchoTool);
1381 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1382 });
1383 cx.run_until_parked();
1384 let completion = fake_model.pending_completions().pop().unwrap();
1385 assert_eq!(
1386 tool_names_for_completion(&completion),
1387 vec!["echo", "test_server_echo"]
1388 );
1389 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1390 LanguageModelToolUse {
1391 id: "tool_2".into(),
1392 name: "test_server_echo".into(),
1393 raw_input: json!({"text": "mcp"}).to_string(),
1394 input: json!({"text": "mcp"}),
1395 is_input_complete: true,
1396 thought_signature: None,
1397 },
1398 ));
1399 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1400 LanguageModelToolUse {
1401 id: "tool_3".into(),
1402 name: "echo".into(),
1403 raw_input: json!({"text": "native"}).to_string(),
1404 input: json!({"text": "native"}),
1405 is_input_complete: true,
1406 thought_signature: None,
1407 },
1408 ));
1409 fake_model.end_last_completion_stream();
1410 cx.run_until_parked();
1411
1412 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1413 assert_eq!(tool_call_params.name, "echo");
1414 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1415 tool_call_response
1416 .send(context_server::types::CallToolResponse {
1417 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1418 is_error: None,
1419 meta: None,
1420 structured_content: None,
1421 })
1422 .unwrap();
1423 cx.run_until_parked();
1424
1425 // Ensure the tool results were inserted with the correct names.
1426 let completion = fake_model.pending_completions().pop().unwrap();
1427 assert_eq!(
1428 completion.messages.last().unwrap().content,
1429 vec![
1430 MessageContent::ToolResult(LanguageModelToolResult {
1431 tool_use_id: "tool_3".into(),
1432 tool_name: "echo".into(),
1433 is_error: false,
1434 content: "native".into(),
1435 output: Some("native".into()),
1436 },),
1437 MessageContent::ToolResult(LanguageModelToolResult {
1438 tool_use_id: "tool_2".into(),
1439 tool_name: "test_server_echo".into(),
1440 is_error: false,
1441 content: "mcp".into(),
1442 output: Some("mcp".into()),
1443 },),
1444 ]
1445 );
1446 fake_model.end_last_completion_stream();
1447 events.collect::<Vec<_>>().await;
1448}
1449
1450#[gpui::test]
1451async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1452 let ThreadTest {
1453 model,
1454 thread,
1455 context_server_store,
1456 fs,
1457 ..
1458 } = setup(cx, TestModel::Fake).await;
1459 let fake_model = model.as_fake();
1460
1461 // Set up a profile with all tools enabled
1462 fs.insert_file(
1463 paths::settings_file(),
1464 json!({
1465 "agent": {
1466 "profiles": {
1467 "test": {
1468 "name": "Test Profile",
1469 "enable_all_context_servers": true,
1470 "tools": {
1471 EchoTool::name(): true,
1472 DelayTool::name(): true,
1473 WordListTool::name(): true,
1474 ToolRequiringPermission::name(): true,
1475 InfiniteTool::name(): true,
1476 }
1477 },
1478 }
1479 }
1480 })
1481 .to_string()
1482 .into_bytes(),
1483 )
1484 .await;
1485 cx.run_until_parked();
1486
1487 thread.update(cx, |thread, cx| {
1488 thread.set_profile(AgentProfileId("test".into()), cx);
1489 thread.add_tool(EchoTool);
1490 thread.add_tool(DelayTool);
1491 thread.add_tool(WordListTool);
1492 thread.add_tool(ToolRequiringPermission);
1493 thread.add_tool(InfiniteTool);
1494 });
1495
1496 // Set up multiple context servers with some overlapping tool names
1497 let _server1_calls = setup_context_server(
1498 "xxx",
1499 vec![
1500 context_server::types::Tool {
1501 name: "echo".into(), // Conflicts with native EchoTool
1502 description: None,
1503 input_schema: serde_json::to_value(EchoTool::input_schema(
1504 LanguageModelToolSchemaFormat::JsonSchema,
1505 ))
1506 .unwrap(),
1507 output_schema: None,
1508 annotations: None,
1509 },
1510 context_server::types::Tool {
1511 name: "unique_tool_1".into(),
1512 description: None,
1513 input_schema: json!({"type": "object", "properties": {}}),
1514 output_schema: None,
1515 annotations: None,
1516 },
1517 ],
1518 &context_server_store,
1519 cx,
1520 );
1521
1522 let _server2_calls = setup_context_server(
1523 "yyy",
1524 vec![
1525 context_server::types::Tool {
1526 name: "echo".into(), // Also conflicts with native EchoTool
1527 description: None,
1528 input_schema: serde_json::to_value(EchoTool::input_schema(
1529 LanguageModelToolSchemaFormat::JsonSchema,
1530 ))
1531 .unwrap(),
1532 output_schema: None,
1533 annotations: None,
1534 },
1535 context_server::types::Tool {
1536 name: "unique_tool_2".into(),
1537 description: None,
1538 input_schema: json!({"type": "object", "properties": {}}),
1539 output_schema: None,
1540 annotations: None,
1541 },
1542 context_server::types::Tool {
1543 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1544 description: None,
1545 input_schema: json!({"type": "object", "properties": {}}),
1546 output_schema: None,
1547 annotations: None,
1548 },
1549 context_server::types::Tool {
1550 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1551 description: None,
1552 input_schema: json!({"type": "object", "properties": {}}),
1553 output_schema: None,
1554 annotations: None,
1555 },
1556 ],
1557 &context_server_store,
1558 cx,
1559 );
1560 let _server3_calls = setup_context_server(
1561 "zzz",
1562 vec![
1563 context_server::types::Tool {
1564 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1565 description: None,
1566 input_schema: json!({"type": "object", "properties": {}}),
1567 output_schema: None,
1568 annotations: None,
1569 },
1570 context_server::types::Tool {
1571 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1572 description: None,
1573 input_schema: json!({"type": "object", "properties": {}}),
1574 output_schema: None,
1575 annotations: None,
1576 },
1577 context_server::types::Tool {
1578 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1579 description: None,
1580 input_schema: json!({"type": "object", "properties": {}}),
1581 output_schema: None,
1582 annotations: None,
1583 },
1584 ],
1585 &context_server_store,
1586 cx,
1587 );
1588
1589 thread
1590 .update(cx, |thread, cx| {
1591 thread.send(UserMessageId::new(), ["Go"], cx)
1592 })
1593 .unwrap();
1594 cx.run_until_parked();
1595 let completion = fake_model.pending_completions().pop().unwrap();
1596 assert_eq!(
1597 tool_names_for_completion(&completion),
1598 vec![
1599 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1600 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1601 "delay",
1602 "echo",
1603 "infinite",
1604 "tool_requiring_permission",
1605 "unique_tool_1",
1606 "unique_tool_2",
1607 "word_list",
1608 "xxx_echo",
1609 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1610 "yyy_echo",
1611 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1612 ]
1613 );
1614}
1615
1616#[gpui::test]
1617#[cfg_attr(not(feature = "e2e"), ignore)]
1618async fn test_cancellation(cx: &mut TestAppContext) {
1619 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1620
1621 let mut events = thread
1622 .update(cx, |thread, cx| {
1623 thread.add_tool(InfiniteTool);
1624 thread.add_tool(EchoTool);
1625 thread.send(
1626 UserMessageId::new(),
1627 ["Call the echo tool, then call the infinite tool, then explain their output"],
1628 cx,
1629 )
1630 })
1631 .unwrap();
1632
1633 // Wait until both tools are called.
1634 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1635 let mut echo_id = None;
1636 let mut echo_completed = false;
1637 while let Some(event) = events.next().await {
1638 match event.unwrap() {
1639 ThreadEvent::ToolCall(tool_call) => {
1640 assert_eq!(tool_call.title, expected_tools.remove(0));
1641 if tool_call.title == "Echo" {
1642 echo_id = Some(tool_call.tool_call_id);
1643 }
1644 }
1645 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1646 acp::ToolCallUpdate {
1647 tool_call_id,
1648 fields:
1649 acp::ToolCallUpdateFields {
1650 status: Some(acp::ToolCallStatus::Completed),
1651 ..
1652 },
1653 ..
1654 },
1655 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1656 echo_completed = true;
1657 }
1658 _ => {}
1659 }
1660
1661 if expected_tools.is_empty() && echo_completed {
1662 break;
1663 }
1664 }
1665
1666 // Cancel the current send and ensure that the event stream is closed, even
1667 // if one of the tools is still running.
1668 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1669 let events = events.collect::<Vec<_>>().await;
1670 let last_event = events.last();
1671 assert!(
1672 matches!(
1673 last_event,
1674 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1675 ),
1676 "unexpected event {last_event:?}"
1677 );
1678
1679 // Ensure we can still send a new message after cancellation.
1680 let events = thread
1681 .update(cx, |thread, cx| {
1682 thread.send(
1683 UserMessageId::new(),
1684 ["Testing: reply with 'Hello' then stop."],
1685 cx,
1686 )
1687 })
1688 .unwrap()
1689 .collect::<Vec<_>>()
1690 .await;
1691 thread.update(cx, |thread, _cx| {
1692 let message = thread.last_message().unwrap();
1693 let agent_message = message.as_agent_message().unwrap();
1694 assert_eq!(
1695 agent_message.content,
1696 vec![AgentMessageContent::Text("Hello".to_string())]
1697 );
1698 });
1699 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1700}
1701
1702#[gpui::test]
1703async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1704 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1705 always_allow_tools(cx);
1706 let fake_model = model.as_fake();
1707
1708 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1709 let environment = Rc::new(FakeThreadEnvironment {
1710 handle: handle.clone(),
1711 });
1712
1713 let mut events = thread
1714 .update(cx, |thread, cx| {
1715 thread.add_tool(crate::TerminalTool::new(
1716 thread.project().clone(),
1717 environment,
1718 ));
1719 thread.send(UserMessageId::new(), ["run a command"], cx)
1720 })
1721 .unwrap();
1722
1723 cx.run_until_parked();
1724
1725 // Simulate the model calling the terminal tool
1726 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1727 LanguageModelToolUse {
1728 id: "terminal_tool_1".into(),
1729 name: "terminal".into(),
1730 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1731 input: json!({"command": "sleep 1000", "cd": "."}),
1732 is_input_complete: true,
1733 thought_signature: None,
1734 },
1735 ));
1736 fake_model.end_last_completion_stream();
1737
1738 // Wait for the terminal tool to start running
1739 wait_for_terminal_tool_started(&mut events, cx).await;
1740
1741 // Cancel the thread while the terminal is running
1742 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1743
1744 // Collect remaining events, driving the executor to let cancellation complete
1745 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1746
1747 // Verify the terminal was killed
1748 assert!(
1749 handle.was_killed(),
1750 "expected terminal handle to be killed on cancellation"
1751 );
1752
1753 // Verify we got a cancellation stop event
1754 assert_eq!(
1755 stop_events(remaining_events),
1756 vec![acp::StopReason::Cancelled],
1757 );
1758
1759 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1760 thread.update(cx, |thread, _cx| {
1761 let message = thread.last_message().unwrap();
1762 let agent_message = message.as_agent_message().unwrap();
1763
1764 let tool_use = agent_message
1765 .content
1766 .iter()
1767 .find_map(|content| match content {
1768 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1769 _ => None,
1770 })
1771 .expect("expected tool use in agent message");
1772
1773 let tool_result = agent_message
1774 .tool_results
1775 .get(&tool_use.id)
1776 .expect("expected tool result");
1777
1778 let result_text = match &tool_result.content {
1779 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1780 _ => panic!("expected text content in tool result"),
1781 };
1782
1783 // "partial output" comes from FakeTerminalHandle's output field
1784 assert!(
1785 result_text.contains("partial output"),
1786 "expected tool result to contain terminal output, got: {result_text}"
1787 );
1788 // Match the actual format from process_content in terminal_tool.rs
1789 assert!(
1790 result_text.contains("The user stopped this command"),
1791 "expected tool result to indicate user stopped, got: {result_text}"
1792 );
1793 });
1794
1795 // Verify we can send a new message after cancellation
1796 verify_thread_recovery(&thread, &fake_model, cx).await;
1797}
1798
1799#[gpui::test]
1800async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1801 // This test verifies that tools which properly handle cancellation via
1802 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1803 // to cancellation and report that they were cancelled.
1804 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1805 always_allow_tools(cx);
1806 let fake_model = model.as_fake();
1807
1808 let (tool, was_cancelled) = CancellationAwareTool::new();
1809
1810 let mut events = thread
1811 .update(cx, |thread, cx| {
1812 thread.add_tool(tool);
1813 thread.send(
1814 UserMessageId::new(),
1815 ["call the cancellation aware tool"],
1816 cx,
1817 )
1818 })
1819 .unwrap();
1820
1821 cx.run_until_parked();
1822
1823 // Simulate the model calling the cancellation-aware tool
1824 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1825 LanguageModelToolUse {
1826 id: "cancellation_aware_1".into(),
1827 name: "cancellation_aware".into(),
1828 raw_input: r#"{}"#.into(),
1829 input: json!({}),
1830 is_input_complete: true,
1831 thought_signature: None,
1832 },
1833 ));
1834 fake_model.end_last_completion_stream();
1835
1836 cx.run_until_parked();
1837
1838 // Wait for the tool call to be reported
1839 let mut tool_started = false;
1840 let deadline = cx.executor().num_cpus() * 100;
1841 for _ in 0..deadline {
1842 cx.run_until_parked();
1843
1844 while let Some(Some(event)) = events.next().now_or_never() {
1845 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1846 if tool_call.title == "Cancellation Aware Tool" {
1847 tool_started = true;
1848 break;
1849 }
1850 }
1851 }
1852
1853 if tool_started {
1854 break;
1855 }
1856
1857 cx.background_executor
1858 .timer(Duration::from_millis(10))
1859 .await;
1860 }
1861 assert!(tool_started, "expected cancellation aware tool to start");
1862
1863 // Cancel the thread and wait for it to complete
1864 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1865
1866 // The cancel task should complete promptly because the tool handles cancellation
1867 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1868 futures::select! {
1869 _ = cancel_task.fuse() => {}
1870 _ = timeout.fuse() => {
1871 panic!("cancel task timed out - tool did not respond to cancellation");
1872 }
1873 }
1874
1875 // Verify the tool detected cancellation via its flag
1876 assert!(
1877 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1878 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1879 );
1880
1881 // Collect remaining events
1882 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1883
1884 // Verify we got a cancellation stop event
1885 assert_eq!(
1886 stop_events(remaining_events),
1887 vec![acp::StopReason::Cancelled],
1888 );
1889
1890 // Verify we can send a new message after cancellation
1891 verify_thread_recovery(&thread, &fake_model, cx).await;
1892}
1893
1894/// Helper to verify thread can recover after cancellation by sending a simple message.
1895async fn verify_thread_recovery(
1896 thread: &Entity<Thread>,
1897 fake_model: &FakeLanguageModel,
1898 cx: &mut TestAppContext,
1899) {
1900 let events = thread
1901 .update(cx, |thread, cx| {
1902 thread.send(
1903 UserMessageId::new(),
1904 ["Testing: reply with 'Hello' then stop."],
1905 cx,
1906 )
1907 })
1908 .unwrap();
1909 cx.run_until_parked();
1910 fake_model.send_last_completion_stream_text_chunk("Hello");
1911 fake_model
1912 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1913 fake_model.end_last_completion_stream();
1914
1915 let events = events.collect::<Vec<_>>().await;
1916 thread.update(cx, |thread, _cx| {
1917 let message = thread.last_message().unwrap();
1918 let agent_message = message.as_agent_message().unwrap();
1919 assert_eq!(
1920 agent_message.content,
1921 vec![AgentMessageContent::Text("Hello".to_string())]
1922 );
1923 });
1924 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1925}
1926
1927/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1928async fn wait_for_terminal_tool_started(
1929 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1930 cx: &mut TestAppContext,
1931) {
1932 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1933 for _ in 0..deadline {
1934 cx.run_until_parked();
1935
1936 while let Some(Some(event)) = events.next().now_or_never() {
1937 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1938 update,
1939 ))) = &event
1940 {
1941 if update.fields.content.as_ref().is_some_and(|content| {
1942 content
1943 .iter()
1944 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1945 }) {
1946 return;
1947 }
1948 }
1949 }
1950
1951 cx.background_executor
1952 .timer(Duration::from_millis(10))
1953 .await;
1954 }
1955 panic!("terminal tool did not start within the expected time");
1956}
1957
1958/// Collects events until a Stop event is received, driving the executor to completion.
1959async fn collect_events_until_stop(
1960 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1961 cx: &mut TestAppContext,
1962) -> Vec<Result<ThreadEvent>> {
1963 let mut collected = Vec::new();
1964 let deadline = cx.executor().num_cpus() * 200;
1965
1966 for _ in 0..deadline {
1967 cx.executor().advance_clock(Duration::from_millis(10));
1968 cx.run_until_parked();
1969
1970 while let Some(Some(event)) = events.next().now_or_never() {
1971 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1972 collected.push(event);
1973 if is_stop {
1974 return collected;
1975 }
1976 }
1977 }
1978 panic!(
1979 "did not receive Stop event within the expected time; collected {} events",
1980 collected.len()
1981 );
1982}
1983
1984#[gpui::test]
1985async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1986 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1987 always_allow_tools(cx);
1988 let fake_model = model.as_fake();
1989
1990 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1991 let environment = Rc::new(FakeThreadEnvironment {
1992 handle: handle.clone(),
1993 });
1994
1995 let message_id = UserMessageId::new();
1996 let mut events = thread
1997 .update(cx, |thread, cx| {
1998 thread.add_tool(crate::TerminalTool::new(
1999 thread.project().clone(),
2000 environment,
2001 ));
2002 thread.send(message_id.clone(), ["run a command"], cx)
2003 })
2004 .unwrap();
2005
2006 cx.run_until_parked();
2007
2008 // Simulate the model calling the terminal tool
2009 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2010 LanguageModelToolUse {
2011 id: "terminal_tool_1".into(),
2012 name: "terminal".into(),
2013 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2014 input: json!({"command": "sleep 1000", "cd": "."}),
2015 is_input_complete: true,
2016 thought_signature: None,
2017 },
2018 ));
2019 fake_model.end_last_completion_stream();
2020
2021 // Wait for the terminal tool to start running
2022 wait_for_terminal_tool_started(&mut events, cx).await;
2023
2024 // Truncate the thread while the terminal is running
2025 thread
2026 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2027 .unwrap();
2028
2029 // Drive the executor to let cancellation complete
2030 let _ = collect_events_until_stop(&mut events, cx).await;
2031
2032 // Verify the terminal was killed
2033 assert!(
2034 handle.was_killed(),
2035 "expected terminal handle to be killed on truncate"
2036 );
2037
2038 // Verify the thread is empty after truncation
2039 thread.update(cx, |thread, _cx| {
2040 assert_eq!(
2041 thread.to_markdown(),
2042 "",
2043 "expected thread to be empty after truncating the only message"
2044 );
2045 });
2046
2047 // Verify we can send a new message after truncation
2048 verify_thread_recovery(&thread, &fake_model, cx).await;
2049}
2050
2051#[gpui::test]
2052async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2053 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2054 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2055 always_allow_tools(cx);
2056 let fake_model = model.as_fake();
2057
2058 let environment = Rc::new(MultiTerminalEnvironment::new());
2059
2060 let mut events = thread
2061 .update(cx, |thread, cx| {
2062 thread.add_tool(crate::TerminalTool::new(
2063 thread.project().clone(),
2064 environment.clone(),
2065 ));
2066 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2067 })
2068 .unwrap();
2069
2070 cx.run_until_parked();
2071
2072 // Simulate the model calling two terminal tools
2073 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2074 LanguageModelToolUse {
2075 id: "terminal_tool_1".into(),
2076 name: "terminal".into(),
2077 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2078 input: json!({"command": "sleep 1000", "cd": "."}),
2079 is_input_complete: true,
2080 thought_signature: None,
2081 },
2082 ));
2083 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2084 LanguageModelToolUse {
2085 id: "terminal_tool_2".into(),
2086 name: "terminal".into(),
2087 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2088 input: json!({"command": "sleep 2000", "cd": "."}),
2089 is_input_complete: true,
2090 thought_signature: None,
2091 },
2092 ));
2093 fake_model.end_last_completion_stream();
2094
2095 // Wait for both terminal tools to start by counting terminal content updates
2096 let mut terminals_started = 0;
2097 let deadline = cx.executor().num_cpus() * 100;
2098 for _ in 0..deadline {
2099 cx.run_until_parked();
2100
2101 while let Some(Some(event)) = events.next().now_or_never() {
2102 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2103 update,
2104 ))) = &event
2105 {
2106 if update.fields.content.as_ref().is_some_and(|content| {
2107 content
2108 .iter()
2109 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2110 }) {
2111 terminals_started += 1;
2112 if terminals_started >= 2 {
2113 break;
2114 }
2115 }
2116 }
2117 }
2118 if terminals_started >= 2 {
2119 break;
2120 }
2121
2122 cx.background_executor
2123 .timer(Duration::from_millis(10))
2124 .await;
2125 }
2126 assert!(
2127 terminals_started >= 2,
2128 "expected 2 terminal tools to start, got {terminals_started}"
2129 );
2130
2131 // Cancel the thread while both terminals are running
2132 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2133
2134 // Collect remaining events
2135 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2136
2137 // Verify both terminal handles were killed
2138 let handles = environment.handles();
2139 assert_eq!(
2140 handles.len(),
2141 2,
2142 "expected 2 terminal handles to be created"
2143 );
2144 assert!(
2145 handles[0].was_killed(),
2146 "expected first terminal handle to be killed on cancellation"
2147 );
2148 assert!(
2149 handles[1].was_killed(),
2150 "expected second terminal handle to be killed on cancellation"
2151 );
2152
2153 // Verify we got a cancellation stop event
2154 assert_eq!(
2155 stop_events(remaining_events),
2156 vec![acp::StopReason::Cancelled],
2157 );
2158}
2159
2160#[gpui::test]
2161async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2162 // Tests that clicking the stop button on the terminal card (as opposed to the main
2163 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2164 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2165 always_allow_tools(cx);
2166 let fake_model = model.as_fake();
2167
2168 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2169 let environment = Rc::new(FakeThreadEnvironment {
2170 handle: handle.clone(),
2171 });
2172
2173 let mut events = thread
2174 .update(cx, |thread, cx| {
2175 thread.add_tool(crate::TerminalTool::new(
2176 thread.project().clone(),
2177 environment,
2178 ));
2179 thread.send(UserMessageId::new(), ["run a command"], cx)
2180 })
2181 .unwrap();
2182
2183 cx.run_until_parked();
2184
2185 // Simulate the model calling the terminal tool
2186 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2187 LanguageModelToolUse {
2188 id: "terminal_tool_1".into(),
2189 name: "terminal".into(),
2190 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2191 input: json!({"command": "sleep 1000", "cd": "."}),
2192 is_input_complete: true,
2193 thought_signature: None,
2194 },
2195 ));
2196 fake_model.end_last_completion_stream();
2197
2198 // Wait for the terminal tool to start running
2199 wait_for_terminal_tool_started(&mut events, cx).await;
2200
2201 // Simulate user clicking stop on the terminal card itself.
2202 // This sets the flag and signals exit (simulating what the real UI would do).
2203 handle.set_stopped_by_user(true);
2204 handle.killed.store(true, Ordering::SeqCst);
2205 handle.signal_exit();
2206
2207 // Wait for the tool to complete
2208 cx.run_until_parked();
2209
2210 // The thread continues after tool completion - simulate the model ending its turn
2211 fake_model
2212 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2213 fake_model.end_last_completion_stream();
2214
2215 // Collect remaining events
2216 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2217
2218 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2219 assert_eq!(
2220 stop_events(remaining_events),
2221 vec![acp::StopReason::EndTurn],
2222 );
2223
2224 // Verify the tool result indicates user stopped
2225 thread.update(cx, |thread, _cx| {
2226 let message = thread.last_message().unwrap();
2227 let agent_message = message.as_agent_message().unwrap();
2228
2229 let tool_use = agent_message
2230 .content
2231 .iter()
2232 .find_map(|content| match content {
2233 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2234 _ => None,
2235 })
2236 .expect("expected tool use in agent message");
2237
2238 let tool_result = agent_message
2239 .tool_results
2240 .get(&tool_use.id)
2241 .expect("expected tool result");
2242
2243 let result_text = match &tool_result.content {
2244 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2245 _ => panic!("expected text content in tool result"),
2246 };
2247
2248 assert!(
2249 result_text.contains("The user stopped this command"),
2250 "expected tool result to indicate user stopped, got: {result_text}"
2251 );
2252 });
2253}
2254
2255#[gpui::test]
2256async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2257 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2258 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2259 always_allow_tools(cx);
2260 let fake_model = model.as_fake();
2261
2262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2263 let environment = Rc::new(FakeThreadEnvironment {
2264 handle: handle.clone(),
2265 });
2266
2267 let mut events = thread
2268 .update(cx, |thread, cx| {
2269 thread.add_tool(crate::TerminalTool::new(
2270 thread.project().clone(),
2271 environment,
2272 ));
2273 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2274 })
2275 .unwrap();
2276
2277 cx.run_until_parked();
2278
2279 // Simulate the model calling the terminal tool with a short timeout
2280 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2281 LanguageModelToolUse {
2282 id: "terminal_tool_1".into(),
2283 name: "terminal".into(),
2284 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2285 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2286 is_input_complete: true,
2287 thought_signature: None,
2288 },
2289 ));
2290 fake_model.end_last_completion_stream();
2291
2292 // Wait for the terminal tool to start running
2293 wait_for_terminal_tool_started(&mut events, cx).await;
2294
2295 // Advance clock past the timeout
2296 cx.executor().advance_clock(Duration::from_millis(200));
2297 cx.run_until_parked();
2298
2299 // The thread continues after tool completion - simulate the model ending its turn
2300 fake_model
2301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2302 fake_model.end_last_completion_stream();
2303
2304 // Collect remaining events
2305 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2306
2307 // Verify the terminal was killed due to timeout
2308 assert!(
2309 handle.was_killed(),
2310 "expected terminal handle to be killed on timeout"
2311 );
2312
2313 // Verify we got an EndTurn (the tool completed, just with timeout)
2314 assert_eq!(
2315 stop_events(remaining_events),
2316 vec![acp::StopReason::EndTurn],
2317 );
2318
2319 // Verify the tool result indicates timeout, not user stopped
2320 thread.update(cx, |thread, _cx| {
2321 let message = thread.last_message().unwrap();
2322 let agent_message = message.as_agent_message().unwrap();
2323
2324 let tool_use = agent_message
2325 .content
2326 .iter()
2327 .find_map(|content| match content {
2328 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2329 _ => None,
2330 })
2331 .expect("expected tool use in agent message");
2332
2333 let tool_result = agent_message
2334 .tool_results
2335 .get(&tool_use.id)
2336 .expect("expected tool result");
2337
2338 let result_text = match &tool_result.content {
2339 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2340 _ => panic!("expected text content in tool result"),
2341 };
2342
2343 assert!(
2344 result_text.contains("timed out"),
2345 "expected tool result to indicate timeout, got: {result_text}"
2346 );
2347 assert!(
2348 !result_text.contains("The user stopped"),
2349 "tool result should not mention user stopped when it timed out, got: {result_text}"
2350 );
2351 });
2352}
2353
2354#[gpui::test]
2355async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2356 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2357 let fake_model = model.as_fake();
2358
2359 let events_1 = thread
2360 .update(cx, |thread, cx| {
2361 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2362 })
2363 .unwrap();
2364 cx.run_until_parked();
2365 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2366 cx.run_until_parked();
2367
2368 let events_2 = thread
2369 .update(cx, |thread, cx| {
2370 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2371 })
2372 .unwrap();
2373 cx.run_until_parked();
2374 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2375 fake_model
2376 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2377 fake_model.end_last_completion_stream();
2378
2379 let events_1 = events_1.collect::<Vec<_>>().await;
2380 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2381 let events_2 = events_2.collect::<Vec<_>>().await;
2382 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2383}
2384
2385#[gpui::test]
2386async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2387 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2388 let fake_model = model.as_fake();
2389
2390 let events_1 = thread
2391 .update(cx, |thread, cx| {
2392 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2393 })
2394 .unwrap();
2395 cx.run_until_parked();
2396 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2397 fake_model
2398 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2399 fake_model.end_last_completion_stream();
2400 let events_1 = events_1.collect::<Vec<_>>().await;
2401
2402 let events_2 = thread
2403 .update(cx, |thread, cx| {
2404 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2405 })
2406 .unwrap();
2407 cx.run_until_parked();
2408 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2409 fake_model
2410 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2411 fake_model.end_last_completion_stream();
2412 let events_2 = events_2.collect::<Vec<_>>().await;
2413
2414 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2415 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2416}
2417
2418#[gpui::test]
2419async fn test_refusal(cx: &mut TestAppContext) {
2420 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2421 let fake_model = model.as_fake();
2422
2423 let events = thread
2424 .update(cx, |thread, cx| {
2425 thread.send(UserMessageId::new(), ["Hello"], cx)
2426 })
2427 .unwrap();
2428 cx.run_until_parked();
2429 thread.read_with(cx, |thread, _| {
2430 assert_eq!(
2431 thread.to_markdown(),
2432 indoc! {"
2433 ## User
2434
2435 Hello
2436 "}
2437 );
2438 });
2439
2440 fake_model.send_last_completion_stream_text_chunk("Hey!");
2441 cx.run_until_parked();
2442 thread.read_with(cx, |thread, _| {
2443 assert_eq!(
2444 thread.to_markdown(),
2445 indoc! {"
2446 ## User
2447
2448 Hello
2449
2450 ## Assistant
2451
2452 Hey!
2453 "}
2454 );
2455 });
2456
2457 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2458 fake_model
2459 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2460 let events = events.collect::<Vec<_>>().await;
2461 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2462 thread.read_with(cx, |thread, _| {
2463 assert_eq!(thread.to_markdown(), "");
2464 });
2465}
2466
2467#[gpui::test]
2468async fn test_truncate_first_message(cx: &mut TestAppContext) {
2469 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2470 let fake_model = model.as_fake();
2471
2472 let message_id = UserMessageId::new();
2473 thread
2474 .update(cx, |thread, cx| {
2475 thread.send(message_id.clone(), ["Hello"], cx)
2476 })
2477 .unwrap();
2478 cx.run_until_parked();
2479 thread.read_with(cx, |thread, _| {
2480 assert_eq!(
2481 thread.to_markdown(),
2482 indoc! {"
2483 ## User
2484
2485 Hello
2486 "}
2487 );
2488 assert_eq!(thread.latest_token_usage(), None);
2489 });
2490
2491 fake_model.send_last_completion_stream_text_chunk("Hey!");
2492 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2493 language_model::TokenUsage {
2494 input_tokens: 32_000,
2495 output_tokens: 16_000,
2496 cache_creation_input_tokens: 0,
2497 cache_read_input_tokens: 0,
2498 },
2499 ));
2500 cx.run_until_parked();
2501 thread.read_with(cx, |thread, _| {
2502 assert_eq!(
2503 thread.to_markdown(),
2504 indoc! {"
2505 ## User
2506
2507 Hello
2508
2509 ## Assistant
2510
2511 Hey!
2512 "}
2513 );
2514 assert_eq!(
2515 thread.latest_token_usage(),
2516 Some(acp_thread::TokenUsage {
2517 used_tokens: 32_000 + 16_000,
2518 max_tokens: 1_000_000,
2519 input_tokens: 32_000,
2520 output_tokens: 16_000,
2521 })
2522 );
2523 });
2524
2525 thread
2526 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2527 .unwrap();
2528 cx.run_until_parked();
2529 thread.read_with(cx, |thread, _| {
2530 assert_eq!(thread.to_markdown(), "");
2531 assert_eq!(thread.latest_token_usage(), None);
2532 });
2533
2534 // Ensure we can still send a new message after truncation.
2535 thread
2536 .update(cx, |thread, cx| {
2537 thread.send(UserMessageId::new(), ["Hi"], cx)
2538 })
2539 .unwrap();
2540 thread.update(cx, |thread, _cx| {
2541 assert_eq!(
2542 thread.to_markdown(),
2543 indoc! {"
2544 ## User
2545
2546 Hi
2547 "}
2548 );
2549 });
2550 cx.run_until_parked();
2551 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2552 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2553 language_model::TokenUsage {
2554 input_tokens: 40_000,
2555 output_tokens: 20_000,
2556 cache_creation_input_tokens: 0,
2557 cache_read_input_tokens: 0,
2558 },
2559 ));
2560 cx.run_until_parked();
2561 thread.read_with(cx, |thread, _| {
2562 assert_eq!(
2563 thread.to_markdown(),
2564 indoc! {"
2565 ## User
2566
2567 Hi
2568
2569 ## Assistant
2570
2571 Ahoy!
2572 "}
2573 );
2574
2575 assert_eq!(
2576 thread.latest_token_usage(),
2577 Some(acp_thread::TokenUsage {
2578 used_tokens: 40_000 + 20_000,
2579 max_tokens: 1_000_000,
2580 input_tokens: 40_000,
2581 output_tokens: 20_000,
2582 })
2583 );
2584 });
2585}
2586
2587#[gpui::test]
2588async fn test_truncate_second_message(cx: &mut TestAppContext) {
2589 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2590 let fake_model = model.as_fake();
2591
2592 thread
2593 .update(cx, |thread, cx| {
2594 thread.send(UserMessageId::new(), ["Message 1"], cx)
2595 })
2596 .unwrap();
2597 cx.run_until_parked();
2598 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2599 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2600 language_model::TokenUsage {
2601 input_tokens: 32_000,
2602 output_tokens: 16_000,
2603 cache_creation_input_tokens: 0,
2604 cache_read_input_tokens: 0,
2605 },
2606 ));
2607 fake_model.end_last_completion_stream();
2608 cx.run_until_parked();
2609
2610 let assert_first_message_state = |cx: &mut TestAppContext| {
2611 thread.clone().read_with(cx, |thread, _| {
2612 assert_eq!(
2613 thread.to_markdown(),
2614 indoc! {"
2615 ## User
2616
2617 Message 1
2618
2619 ## Assistant
2620
2621 Message 1 response
2622 "}
2623 );
2624
2625 assert_eq!(
2626 thread.latest_token_usage(),
2627 Some(acp_thread::TokenUsage {
2628 used_tokens: 32_000 + 16_000,
2629 max_tokens: 1_000_000,
2630 input_tokens: 32_000,
2631 output_tokens: 16_000,
2632 })
2633 );
2634 });
2635 };
2636
2637 assert_first_message_state(cx);
2638
2639 let second_message_id = UserMessageId::new();
2640 thread
2641 .update(cx, |thread, cx| {
2642 thread.send(second_message_id.clone(), ["Message 2"], cx)
2643 })
2644 .unwrap();
2645 cx.run_until_parked();
2646
2647 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2648 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2649 language_model::TokenUsage {
2650 input_tokens: 40_000,
2651 output_tokens: 20_000,
2652 cache_creation_input_tokens: 0,
2653 cache_read_input_tokens: 0,
2654 },
2655 ));
2656 fake_model.end_last_completion_stream();
2657 cx.run_until_parked();
2658
2659 thread.read_with(cx, |thread, _| {
2660 assert_eq!(
2661 thread.to_markdown(),
2662 indoc! {"
2663 ## User
2664
2665 Message 1
2666
2667 ## Assistant
2668
2669 Message 1 response
2670
2671 ## User
2672
2673 Message 2
2674
2675 ## Assistant
2676
2677 Message 2 response
2678 "}
2679 );
2680
2681 assert_eq!(
2682 thread.latest_token_usage(),
2683 Some(acp_thread::TokenUsage {
2684 used_tokens: 40_000 + 20_000,
2685 max_tokens: 1_000_000,
2686 input_tokens: 40_000,
2687 output_tokens: 20_000,
2688 })
2689 );
2690 });
2691
2692 thread
2693 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2694 .unwrap();
2695 cx.run_until_parked();
2696
2697 assert_first_message_state(cx);
2698}
2699
2700#[gpui::test]
2701async fn test_title_generation(cx: &mut TestAppContext) {
2702 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2703 let fake_model = model.as_fake();
2704
2705 let summary_model = Arc::new(FakeLanguageModel::default());
2706 thread.update(cx, |thread, cx| {
2707 thread.set_summarization_model(Some(summary_model.clone()), cx)
2708 });
2709
2710 let send = thread
2711 .update(cx, |thread, cx| {
2712 thread.send(UserMessageId::new(), ["Hello"], cx)
2713 })
2714 .unwrap();
2715 cx.run_until_parked();
2716
2717 fake_model.send_last_completion_stream_text_chunk("Hey!");
2718 fake_model.end_last_completion_stream();
2719 cx.run_until_parked();
2720 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2721
2722 // Ensure the summary model has been invoked to generate a title.
2723 summary_model.send_last_completion_stream_text_chunk("Hello ");
2724 summary_model.send_last_completion_stream_text_chunk("world\nG");
2725 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2726 summary_model.end_last_completion_stream();
2727 send.collect::<Vec<_>>().await;
2728 cx.run_until_parked();
2729 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2730
2731 // Send another message, ensuring no title is generated this time.
2732 let send = thread
2733 .update(cx, |thread, cx| {
2734 thread.send(UserMessageId::new(), ["Hello again"], cx)
2735 })
2736 .unwrap();
2737 cx.run_until_parked();
2738 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2739 fake_model.end_last_completion_stream();
2740 cx.run_until_parked();
2741 assert_eq!(summary_model.pending_completions(), Vec::new());
2742 send.collect::<Vec<_>>().await;
2743 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2744}
2745
2746#[gpui::test]
2747async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2748 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2749 let fake_model = model.as_fake();
2750
2751 let _events = thread
2752 .update(cx, |thread, cx| {
2753 thread.add_tool(ToolRequiringPermission);
2754 thread.add_tool(EchoTool);
2755 thread.send(UserMessageId::new(), ["Hey!"], cx)
2756 })
2757 .unwrap();
2758 cx.run_until_parked();
2759
2760 let permission_tool_use = LanguageModelToolUse {
2761 id: "tool_id_1".into(),
2762 name: ToolRequiringPermission::name().into(),
2763 raw_input: "{}".into(),
2764 input: json!({}),
2765 is_input_complete: true,
2766 thought_signature: None,
2767 };
2768 let echo_tool_use = LanguageModelToolUse {
2769 id: "tool_id_2".into(),
2770 name: EchoTool::name().into(),
2771 raw_input: json!({"text": "test"}).to_string(),
2772 input: json!({"text": "test"}),
2773 is_input_complete: true,
2774 thought_signature: None,
2775 };
2776 fake_model.send_last_completion_stream_text_chunk("Hi!");
2777 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2778 permission_tool_use,
2779 ));
2780 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2781 echo_tool_use.clone(),
2782 ));
2783 fake_model.end_last_completion_stream();
2784 cx.run_until_parked();
2785
2786 // Ensure pending tools are skipped when building a request.
2787 let request = thread
2788 .read_with(cx, |thread, cx| {
2789 thread.build_completion_request(CompletionIntent::EditFile, cx)
2790 })
2791 .unwrap();
2792 assert_eq!(
2793 request.messages[1..],
2794 vec![
2795 LanguageModelRequestMessage {
2796 role: Role::User,
2797 content: vec!["Hey!".into()],
2798 cache: true,
2799 reasoning_details: None,
2800 },
2801 LanguageModelRequestMessage {
2802 role: Role::Assistant,
2803 content: vec![
2804 MessageContent::Text("Hi!".into()),
2805 MessageContent::ToolUse(echo_tool_use.clone())
2806 ],
2807 cache: false,
2808 reasoning_details: None,
2809 },
2810 LanguageModelRequestMessage {
2811 role: Role::User,
2812 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2813 tool_use_id: echo_tool_use.id.clone(),
2814 tool_name: echo_tool_use.name,
2815 is_error: false,
2816 content: "test".into(),
2817 output: Some("test".into())
2818 })],
2819 cache: false,
2820 reasoning_details: None,
2821 },
2822 ],
2823 );
2824}
2825
2826#[gpui::test]
2827async fn test_agent_connection(cx: &mut TestAppContext) {
2828 cx.update(settings::init);
2829 let templates = Templates::new();
2830
2831 // Initialize language model system with test provider
2832 cx.update(|cx| {
2833 gpui_tokio::init(cx);
2834
2835 let http_client = FakeHttpClient::with_404_response();
2836 let clock = Arc::new(clock::FakeSystemClock::new());
2837 let client = Client::new(clock, http_client, cx);
2838 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2839 language_model::init(client.clone(), cx);
2840 language_models::init(user_store, client.clone(), cx);
2841 LanguageModelRegistry::test(cx);
2842 });
2843 cx.executor().forbid_parking();
2844
2845 // Create a project for new_thread
2846 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2847 fake_fs.insert_tree(path!("/test"), json!({})).await;
2848 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2849 let cwd = Path::new("/test");
2850 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2851
2852 // Create agent and connection
2853 let agent = NativeAgent::new(
2854 project.clone(),
2855 thread_store,
2856 templates.clone(),
2857 None,
2858 fake_fs.clone(),
2859 &mut cx.to_async(),
2860 )
2861 .await
2862 .unwrap();
2863 let connection = NativeAgentConnection(agent.clone());
2864
2865 // Create a thread using new_thread
2866 let connection_rc = Rc::new(connection.clone());
2867 let acp_thread = cx
2868 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2869 .await
2870 .expect("new_thread should succeed");
2871
2872 // Get the session_id from the AcpThread
2873 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2874
2875 // Test model_selector returns Some
2876 let selector_opt = connection.model_selector(&session_id);
2877 assert!(
2878 selector_opt.is_some(),
2879 "agent should always support ModelSelector"
2880 );
2881 let selector = selector_opt.unwrap();
2882
2883 // Test list_models
2884 let listed_models = cx
2885 .update(|cx| selector.list_models(cx))
2886 .await
2887 .expect("list_models should succeed");
2888 let AgentModelList::Grouped(listed_models) = listed_models else {
2889 panic!("Unexpected model list type");
2890 };
2891 assert!(!listed_models.is_empty(), "should have at least one model");
2892 assert_eq!(
2893 listed_models[&AgentModelGroupName("Fake".into())][0]
2894 .id
2895 .0
2896 .as_ref(),
2897 "fake/fake"
2898 );
2899
2900 // Test selected_model returns the default
2901 let model = cx
2902 .update(|cx| selector.selected_model(cx))
2903 .await
2904 .expect("selected_model should succeed");
2905 let model = cx
2906 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2907 .unwrap();
2908 let model = model.as_fake();
2909 assert_eq!(model.id().0, "fake", "should return default model");
2910
2911 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2912 cx.run_until_parked();
2913 model.send_last_completion_stream_text_chunk("def");
2914 cx.run_until_parked();
2915 acp_thread.read_with(cx, |thread, cx| {
2916 assert_eq!(
2917 thread.to_markdown(cx),
2918 indoc! {"
2919 ## User
2920
2921 abc
2922
2923 ## Assistant
2924
2925 def
2926
2927 "}
2928 )
2929 });
2930
2931 // Test cancel
2932 cx.update(|cx| connection.cancel(&session_id, cx));
2933 request.await.expect("prompt should fail gracefully");
2934
2935 // Ensure that dropping the ACP thread causes the native thread to be
2936 // dropped as well.
2937 cx.update(|_| drop(acp_thread));
2938 let result = cx
2939 .update(|cx| {
2940 connection.prompt(
2941 Some(acp_thread::UserMessageId::new()),
2942 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2943 cx,
2944 )
2945 })
2946 .await;
2947 assert_eq!(
2948 result.as_ref().unwrap_err().to_string(),
2949 "Session not found",
2950 "unexpected result: {:?}",
2951 result
2952 );
2953}
2954
2955#[gpui::test]
2956async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2957 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2958 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2959 let fake_model = model.as_fake();
2960
2961 let mut events = thread
2962 .update(cx, |thread, cx| {
2963 thread.send(UserMessageId::new(), ["Think"], cx)
2964 })
2965 .unwrap();
2966 cx.run_until_parked();
2967
2968 // Simulate streaming partial input.
2969 let input = json!({});
2970 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2971 LanguageModelToolUse {
2972 id: "1".into(),
2973 name: ThinkingTool::name().into(),
2974 raw_input: input.to_string(),
2975 input,
2976 is_input_complete: false,
2977 thought_signature: None,
2978 },
2979 ));
2980
2981 // Input streaming completed
2982 let input = json!({ "content": "Thinking hard!" });
2983 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2984 LanguageModelToolUse {
2985 id: "1".into(),
2986 name: "thinking".into(),
2987 raw_input: input.to_string(),
2988 input,
2989 is_input_complete: true,
2990 thought_signature: None,
2991 },
2992 ));
2993 fake_model.end_last_completion_stream();
2994 cx.run_until_parked();
2995
2996 let tool_call = expect_tool_call(&mut events).await;
2997 assert_eq!(
2998 tool_call,
2999 acp::ToolCall::new("1", "Thinking")
3000 .kind(acp::ToolKind::Think)
3001 .raw_input(json!({}))
3002 .meta(acp::Meta::from_iter([(
3003 "tool_name".into(),
3004 "thinking".into()
3005 )]))
3006 );
3007 let update = expect_tool_call_update_fields(&mut events).await;
3008 assert_eq!(
3009 update,
3010 acp::ToolCallUpdate::new(
3011 "1",
3012 acp::ToolCallUpdateFields::new()
3013 .title("Thinking")
3014 .kind(acp::ToolKind::Think)
3015 .raw_input(json!({ "content": "Thinking hard!"}))
3016 )
3017 );
3018 let update = expect_tool_call_update_fields(&mut events).await;
3019 assert_eq!(
3020 update,
3021 acp::ToolCallUpdate::new(
3022 "1",
3023 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3024 )
3025 );
3026 let update = expect_tool_call_update_fields(&mut events).await;
3027 assert_eq!(
3028 update,
3029 acp::ToolCallUpdate::new(
3030 "1",
3031 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3032 )
3033 );
3034 let update = expect_tool_call_update_fields(&mut events).await;
3035 assert_eq!(
3036 update,
3037 acp::ToolCallUpdate::new(
3038 "1",
3039 acp::ToolCallUpdateFields::new()
3040 .status(acp::ToolCallStatus::Completed)
3041 .raw_output("Finished thinking.")
3042 )
3043 );
3044}
3045
3046#[gpui::test]
3047async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3048 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3049 let fake_model = model.as_fake();
3050
3051 let mut events = thread
3052 .update(cx, |thread, cx| {
3053 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3054 thread.send(UserMessageId::new(), ["Hello!"], cx)
3055 })
3056 .unwrap();
3057 cx.run_until_parked();
3058
3059 fake_model.send_last_completion_stream_text_chunk("Hey!");
3060 fake_model.end_last_completion_stream();
3061
3062 let mut retry_events = Vec::new();
3063 while let Some(Ok(event)) = events.next().await {
3064 match event {
3065 ThreadEvent::Retry(retry_status) => {
3066 retry_events.push(retry_status);
3067 }
3068 ThreadEvent::Stop(..) => break,
3069 _ => {}
3070 }
3071 }
3072
3073 assert_eq!(retry_events.len(), 0);
3074 thread.read_with(cx, |thread, _cx| {
3075 assert_eq!(
3076 thread.to_markdown(),
3077 indoc! {"
3078 ## User
3079
3080 Hello!
3081
3082 ## Assistant
3083
3084 Hey!
3085 "}
3086 )
3087 });
3088}
3089
3090#[gpui::test]
3091async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3092 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3093 let fake_model = model.as_fake();
3094
3095 let mut events = thread
3096 .update(cx, |thread, cx| {
3097 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3098 thread.send(UserMessageId::new(), ["Hello!"], cx)
3099 })
3100 .unwrap();
3101 cx.run_until_parked();
3102
3103 fake_model.send_last_completion_stream_text_chunk("Hey,");
3104 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3105 provider: LanguageModelProviderName::new("Anthropic"),
3106 retry_after: Some(Duration::from_secs(3)),
3107 });
3108 fake_model.end_last_completion_stream();
3109
3110 cx.executor().advance_clock(Duration::from_secs(3));
3111 cx.run_until_parked();
3112
3113 fake_model.send_last_completion_stream_text_chunk("there!");
3114 fake_model.end_last_completion_stream();
3115 cx.run_until_parked();
3116
3117 let mut retry_events = Vec::new();
3118 while let Some(Ok(event)) = events.next().await {
3119 match event {
3120 ThreadEvent::Retry(retry_status) => {
3121 retry_events.push(retry_status);
3122 }
3123 ThreadEvent::Stop(..) => break,
3124 _ => {}
3125 }
3126 }
3127
3128 assert_eq!(retry_events.len(), 1);
3129 assert!(matches!(
3130 retry_events[0],
3131 acp_thread::RetryStatus { attempt: 1, .. }
3132 ));
3133 thread.read_with(cx, |thread, _cx| {
3134 assert_eq!(
3135 thread.to_markdown(),
3136 indoc! {"
3137 ## User
3138
3139 Hello!
3140
3141 ## Assistant
3142
3143 Hey,
3144
3145 [resume]
3146
3147 ## Assistant
3148
3149 there!
3150 "}
3151 )
3152 });
3153}
3154
3155#[gpui::test]
3156async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3157 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3158 let fake_model = model.as_fake();
3159
3160 let events = thread
3161 .update(cx, |thread, cx| {
3162 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3163 thread.add_tool(EchoTool);
3164 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3165 })
3166 .unwrap();
3167 cx.run_until_parked();
3168
3169 let tool_use_1 = LanguageModelToolUse {
3170 id: "tool_1".into(),
3171 name: EchoTool::name().into(),
3172 raw_input: json!({"text": "test"}).to_string(),
3173 input: json!({"text": "test"}),
3174 is_input_complete: true,
3175 thought_signature: None,
3176 };
3177 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3178 tool_use_1.clone(),
3179 ));
3180 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3181 provider: LanguageModelProviderName::new("Anthropic"),
3182 retry_after: Some(Duration::from_secs(3)),
3183 });
3184 fake_model.end_last_completion_stream();
3185
3186 cx.executor().advance_clock(Duration::from_secs(3));
3187 let completion = fake_model.pending_completions().pop().unwrap();
3188 assert_eq!(
3189 completion.messages[1..],
3190 vec![
3191 LanguageModelRequestMessage {
3192 role: Role::User,
3193 content: vec!["Call the echo tool!".into()],
3194 cache: false,
3195 reasoning_details: None,
3196 },
3197 LanguageModelRequestMessage {
3198 role: Role::Assistant,
3199 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3200 cache: false,
3201 reasoning_details: None,
3202 },
3203 LanguageModelRequestMessage {
3204 role: Role::User,
3205 content: vec![language_model::MessageContent::ToolResult(
3206 LanguageModelToolResult {
3207 tool_use_id: tool_use_1.id.clone(),
3208 tool_name: tool_use_1.name.clone(),
3209 is_error: false,
3210 content: "test".into(),
3211 output: Some("test".into())
3212 }
3213 )],
3214 cache: true,
3215 reasoning_details: None,
3216 },
3217 ]
3218 );
3219
3220 fake_model.send_last_completion_stream_text_chunk("Done");
3221 fake_model.end_last_completion_stream();
3222 cx.run_until_parked();
3223 events.collect::<Vec<_>>().await;
3224 thread.read_with(cx, |thread, _cx| {
3225 assert_eq!(
3226 thread.last_message(),
3227 Some(Message::Agent(AgentMessage {
3228 content: vec![AgentMessageContent::Text("Done".into())],
3229 tool_results: IndexMap::default(),
3230 reasoning_details: None,
3231 }))
3232 );
3233 })
3234}
3235
3236#[gpui::test]
3237async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3238 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3239 let fake_model = model.as_fake();
3240
3241 let mut events = thread
3242 .update(cx, |thread, cx| {
3243 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3244 thread.send(UserMessageId::new(), ["Hello!"], cx)
3245 })
3246 .unwrap();
3247 cx.run_until_parked();
3248
3249 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3250 fake_model.send_last_completion_stream_error(
3251 LanguageModelCompletionError::ServerOverloaded {
3252 provider: LanguageModelProviderName::new("Anthropic"),
3253 retry_after: Some(Duration::from_secs(3)),
3254 },
3255 );
3256 fake_model.end_last_completion_stream();
3257 cx.executor().advance_clock(Duration::from_secs(3));
3258 cx.run_until_parked();
3259 }
3260
3261 let mut errors = Vec::new();
3262 let mut retry_events = Vec::new();
3263 while let Some(event) = events.next().await {
3264 match event {
3265 Ok(ThreadEvent::Retry(retry_status)) => {
3266 retry_events.push(retry_status);
3267 }
3268 Ok(ThreadEvent::Stop(..)) => break,
3269 Err(error) => errors.push(error),
3270 _ => {}
3271 }
3272 }
3273
3274 assert_eq!(
3275 retry_events.len(),
3276 crate::thread::MAX_RETRY_ATTEMPTS as usize
3277 );
3278 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3279 assert_eq!(retry_events[i].attempt, i + 1);
3280 }
3281 assert_eq!(errors.len(), 1);
3282 let error = errors[0]
3283 .downcast_ref::<LanguageModelCompletionError>()
3284 .unwrap();
3285 assert!(matches!(
3286 error,
3287 LanguageModelCompletionError::ServerOverloaded { .. }
3288 ));
3289}
3290
3291/// Filters out the stop events for asserting against in tests
3292fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3293 result_events
3294 .into_iter()
3295 .filter_map(|event| match event.unwrap() {
3296 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3297 _ => None,
3298 })
3299 .collect()
3300}
3301
/// Bundle of handles produced by `setup`; tests destructure the parts they need.
struct ThreadTest {
    /// Language model the thread was created with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context entity that was passed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// Context server store taken from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem holding the settings file and the `/test` project tree.
    fs: Arc<FakeFs>,
}
3309
/// Selects which language model `setup` attaches to the thread.
enum TestModel {
    /// A registered model, looked up in the global registry by its id.
    Sonnet4,
    /// An in-process `FakeLanguageModel` that the test drives directly.
    Fake,
}
3314
3315impl TestModel {
3316 fn id(&self) -> LanguageModelId {
3317 match self {
3318 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3319 TestModel::Fake => unreachable!(),
3320 }
3321 }
3322}
3323
/// Builds the shared fixture for thread tests.
///
/// Writes a settings file with a "test-profile" agent profile enabling the test
/// tools, initializes settings (and, for `TestModel::Sonnet4`, the production
/// client and language model providers), creates a project rooted at `/test`,
/// resolves the requested model, and constructs a `Thread` with it.
///
/// Panics if the settings directory cannot be created, or (for `Sonnet4`) if
/// the model is not found in the registry or authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed so the real-model path may block on
    // network I/O — confirm.
    cx.executor().allow_parking();

    // Seed the fake filesystem with a settings file that defines the default
    // agent profile and the tools it enables.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider initialization.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Set up a production client and the language model providers
                // so the model can be found in the registry below.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Start applying the settings file written above to the settings store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: a ready fake, or a registry model after its provider
    // has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // Assemble the thread from the project, a default project context, and a
    // registry wrapping the project's context server store.
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3427
/// Constructor hook that initializes `env_logger`, but only when the
/// `RUST_LOG` environment variable is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3435
3436fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3437 let fs = fs.clone();
3438 cx.spawn({
3439 async move |cx| {
3440 let mut new_settings_content_rx = settings::watch_config_file(
3441 cx.background_executor(),
3442 fs,
3443 paths::settings_file().clone(),
3444 );
3445
3446 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3447 cx.update(|cx| {
3448 SettingsStore::update_global(cx, |settings, cx| {
3449 settings.set_user_settings(&new_settings_content, cx)
3450 })
3451 })
3452 .ok();
3453 }
3454 }
3455 })
3456 .detach();
3457}
3458
3459fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3460 completion
3461 .tools
3462 .iter()
3463 .map(|tool| tool.name.clone())
3464 .collect()
3465}
3466
/// Registers and starts a fake context server called `name` that advertises
/// `tools`.
///
/// The server is first added to the global `ProjectSettings` as an enabled
/// stdio server, then started on `context_server_store` with a fake transport
/// answering `Initialize`, `ListTools` and `CallTool` requests.
///
/// Returns a receiver that yields one `(params, response_sender)` pair per
/// `CallTool` request; the test answers the call by sending a response through
/// the sender.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add the server to the project settings. The command is a placeholder;
    // the server started below uses the fake transport instead.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Initialize: report the latest protocol version and tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // ListTools: always answer with the caller-supplied tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // CallTool: forward the call to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let pending work triggered by starting the server run before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3546
3547#[gpui::test]
3548async fn test_tokens_before_message(cx: &mut TestAppContext) {
3549 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3550 let fake_model = model.as_fake();
3551
3552 // First message
3553 let message_1_id = UserMessageId::new();
3554 thread
3555 .update(cx, |thread, cx| {
3556 thread.send(message_1_id.clone(), ["First message"], cx)
3557 })
3558 .unwrap();
3559 cx.run_until_parked();
3560
3561 // Before any response, tokens_before_message should return None for first message
3562 thread.read_with(cx, |thread, _| {
3563 assert_eq!(
3564 thread.tokens_before_message(&message_1_id),
3565 None,
3566 "First message should have no tokens before it"
3567 );
3568 });
3569
3570 // Complete first message with usage
3571 fake_model.send_last_completion_stream_text_chunk("Response 1");
3572 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3573 language_model::TokenUsage {
3574 input_tokens: 100,
3575 output_tokens: 50,
3576 cache_creation_input_tokens: 0,
3577 cache_read_input_tokens: 0,
3578 },
3579 ));
3580 fake_model.end_last_completion_stream();
3581 cx.run_until_parked();
3582
3583 // First message still has no tokens before it
3584 thread.read_with(cx, |thread, _| {
3585 assert_eq!(
3586 thread.tokens_before_message(&message_1_id),
3587 None,
3588 "First message should still have no tokens before it after response"
3589 );
3590 });
3591
3592 // Second message
3593 let message_2_id = UserMessageId::new();
3594 thread
3595 .update(cx, |thread, cx| {
3596 thread.send(message_2_id.clone(), ["Second message"], cx)
3597 })
3598 .unwrap();
3599 cx.run_until_parked();
3600
3601 // Second message should have first message's input tokens before it
3602 thread.read_with(cx, |thread, _| {
3603 assert_eq!(
3604 thread.tokens_before_message(&message_2_id),
3605 Some(100),
3606 "Second message should have 100 tokens before it (from first request)"
3607 );
3608 });
3609
3610 // Complete second message
3611 fake_model.send_last_completion_stream_text_chunk("Response 2");
3612 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3613 language_model::TokenUsage {
3614 input_tokens: 250, // Total for this request (includes previous context)
3615 output_tokens: 75,
3616 cache_creation_input_tokens: 0,
3617 cache_read_input_tokens: 0,
3618 },
3619 ));
3620 fake_model.end_last_completion_stream();
3621 cx.run_until_parked();
3622
3623 // Third message
3624 let message_3_id = UserMessageId::new();
3625 thread
3626 .update(cx, |thread, cx| {
3627 thread.send(message_3_id.clone(), ["Third message"], cx)
3628 })
3629 .unwrap();
3630 cx.run_until_parked();
3631
3632 // Third message should have second message's input tokens (250) before it
3633 thread.read_with(cx, |thread, _| {
3634 assert_eq!(
3635 thread.tokens_before_message(&message_3_id),
3636 Some(250),
3637 "Third message should have 250 tokens before it (from second request)"
3638 );
3639 // Second message should still have 100
3640 assert_eq!(
3641 thread.tokens_before_message(&message_2_id),
3642 Some(100),
3643 "Second message should still have 100 tokens before it"
3644 );
3645 // First message still has none
3646 assert_eq!(
3647 thread.tokens_before_message(&message_1_id),
3648 None,
3649 "First message should still have no tokens before it"
3650 );
3651 });
3652}
3653
3654#[gpui::test]
3655async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3656 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3657 let fake_model = model.as_fake();
3658
3659 // Set up three messages with responses
3660 let message_1_id = UserMessageId::new();
3661 thread
3662 .update(cx, |thread, cx| {
3663 thread.send(message_1_id.clone(), ["Message 1"], cx)
3664 })
3665 .unwrap();
3666 cx.run_until_parked();
3667 fake_model.send_last_completion_stream_text_chunk("Response 1");
3668 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3669 language_model::TokenUsage {
3670 input_tokens: 100,
3671 output_tokens: 50,
3672 cache_creation_input_tokens: 0,
3673 cache_read_input_tokens: 0,
3674 },
3675 ));
3676 fake_model.end_last_completion_stream();
3677 cx.run_until_parked();
3678
3679 let message_2_id = UserMessageId::new();
3680 thread
3681 .update(cx, |thread, cx| {
3682 thread.send(message_2_id.clone(), ["Message 2"], cx)
3683 })
3684 .unwrap();
3685 cx.run_until_parked();
3686 fake_model.send_last_completion_stream_text_chunk("Response 2");
3687 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3688 language_model::TokenUsage {
3689 input_tokens: 250,
3690 output_tokens: 75,
3691 cache_creation_input_tokens: 0,
3692 cache_read_input_tokens: 0,
3693 },
3694 ));
3695 fake_model.end_last_completion_stream();
3696 cx.run_until_parked();
3697
3698 // Verify initial state
3699 thread.read_with(cx, |thread, _| {
3700 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3701 });
3702
3703 // Truncate at message 2 (removes message 2 and everything after)
3704 thread
3705 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3706 .unwrap();
3707 cx.run_until_parked();
3708
3709 // After truncation, message_2_id no longer exists, so lookup should return None
3710 thread.read_with(cx, |thread, _| {
3711 assert_eq!(
3712 thread.tokens_before_message(&message_2_id),
3713 None,
3714 "After truncation, message 2 no longer exists"
3715 );
3716 // Message 1 still exists but has no tokens before it
3717 assert_eq!(
3718 thread.tokens_before_message(&message_1_id),
3719 None,
3720 "First message still has no tokens before it"
3721 );
3722 });
3723}
3724
// Exercises the "terminal" tool permission rules in four scenarios that share
// one project. Each scenario overrides the global `AgentSettings` with a fresh
// `ToolRules` entry for "terminal", runs `TerminalTool` with a fake terminal
// handle, and checks whether the command is blocked, run, or sent for
// authorization.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The command matches the deny pattern, so the run must fail with an
        // error whose text contains "blocked".
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        assert!(
            result.unwrap_err().to_string().contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The first event must be a field update carrying terminal content,
        // i.e. the command ran without an authorization request.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        // The task is held but never awaited; this scenario only inspects the
        // authorization request it emits.
        let _task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let auth = rx.expect_authorization().await;
        assert!(
            auth.tool_call.fields.title.is_some(),
            "expected authorization request for sudo command despite always_allow_tool_actions=true"
        );
    }

    // Test 4: default_mode: Deny blocks commands when no pattern matches
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // No pattern is configured, so the Deny default applies and the error
        // text must contain "disabled".
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by default_mode: Deny"
        );
        assert!(
            result.unwrap_err().to_string().contains("disabled"),
            "error should mention the tool is disabled"
        );
    }
}
3940
3941#[gpui::test]
3942async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3943 init_test(cx);
3944
3945 cx.update(|cx| {
3946 cx.update_flags(true, vec!["subagents".to_string()]);
3947 });
3948
3949 let fs = FakeFs::new(cx.executor());
3950 fs.insert_tree(path!("/test"), json!({})).await;
3951 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3952 let project_context = cx.new(|_cx| ProjectContext::default());
3953 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3954 let context_server_registry =
3955 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3956 let model = Arc::new(FakeLanguageModel::default());
3957
3958 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3959 let environment = Rc::new(FakeThreadEnvironment { handle });
3960
3961 let thread = cx.new(|cx| {
3962 let mut thread = Thread::new(
3963 project.clone(),
3964 project_context,
3965 context_server_registry,
3966 Templates::new(),
3967 Some(model),
3968 cx,
3969 );
3970 thread.add_default_tools(environment, cx);
3971 thread
3972 });
3973
3974 thread.read_with(cx, |thread, _| {
3975 assert!(
3976 thread.has_registered_tool("subagent"),
3977 "subagent tool should be present when feature flag is enabled"
3978 );
3979 });
3980}
3981
3982#[gpui::test]
3983async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3984 init_test(cx);
3985
3986 cx.update(|cx| {
3987 cx.update_flags(true, vec!["subagents".to_string()]);
3988 });
3989
3990 let fs = FakeFs::new(cx.executor());
3991 fs.insert_tree(path!("/test"), json!({})).await;
3992 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3993 let project_context = cx.new(|_cx| ProjectContext::default());
3994 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3995 let context_server_registry =
3996 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3997 let model = Arc::new(FakeLanguageModel::default());
3998
3999 let subagent_context = SubagentContext {
4000 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4001 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4002 depth: 1,
4003 summary_prompt: "Summarize".to_string(),
4004 context_low_prompt: "Context low".to_string(),
4005 };
4006
4007 let subagent = cx.new(|cx| {
4008 Thread::new_subagent(
4009 project.clone(),
4010 project_context,
4011 context_server_registry,
4012 Templates::new(),
4013 model.clone(),
4014 subagent_context,
4015 std::collections::BTreeMap::new(),
4016 cx,
4017 )
4018 });
4019
4020 subagent.read_with(cx, |thread, _| {
4021 assert!(thread.is_subagent());
4022 assert_eq!(thread.depth(), 1);
4023 assert!(thread.model().is_some());
4024 });
4025}
4026
4027#[gpui::test]
4028async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4029 init_test(cx);
4030
4031 cx.update(|cx| {
4032 cx.update_flags(true, vec!["subagents".to_string()]);
4033 });
4034
4035 let fs = FakeFs::new(cx.executor());
4036 fs.insert_tree(path!("/test"), json!({})).await;
4037 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4038 let project_context = cx.new(|_cx| ProjectContext::default());
4039 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4040 let context_server_registry =
4041 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4042 let model = Arc::new(FakeLanguageModel::default());
4043
4044 let subagent_context = SubagentContext {
4045 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4046 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4047 depth: MAX_SUBAGENT_DEPTH,
4048 summary_prompt: "Summarize".to_string(),
4049 context_low_prompt: "Context low".to_string(),
4050 };
4051
4052 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
4053 let environment = Rc::new(FakeThreadEnvironment { handle });
4054
4055 let deep_subagent = cx.new(|cx| {
4056 let mut thread = Thread::new_subagent(
4057 project.clone(),
4058 project_context,
4059 context_server_registry,
4060 Templates::new(),
4061 model.clone(),
4062 subagent_context,
4063 std::collections::BTreeMap::new(),
4064 cx,
4065 );
4066 thread.add_default_tools(environment, cx);
4067 thread
4068 });
4069
4070 deep_subagent.read_with(cx, |thread, _| {
4071 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4072 assert!(
4073 !thread.has_registered_tool("subagent"),
4074 "subagent tool should not be present at max depth"
4075 );
4076 });
4077}
4078
4079#[gpui::test]
4080async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
4081 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4082 let fake_model = model.as_fake();
4083
4084 cx.update(|cx| {
4085 cx.update_flags(true, vec!["subagents".to_string()]);
4086 });
4087
4088 let subagent_context = SubagentContext {
4089 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4090 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4091 depth: 1,
4092 summary_prompt: "Summarize your work".to_string(),
4093 context_low_prompt: "Context low, wrap up".to_string(),
4094 };
4095
4096 let project = thread.read_with(cx, |t, _| t.project.clone());
4097 let project_context = cx.new(|_cx| ProjectContext::default());
4098 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4099 let context_server_registry =
4100 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4101
4102 let subagent = cx.new(|cx| {
4103 Thread::new_subagent(
4104 project.clone(),
4105 project_context,
4106 context_server_registry,
4107 Templates::new(),
4108 model.clone(),
4109 subagent_context,
4110 std::collections::BTreeMap::new(),
4111 cx,
4112 )
4113 });
4114
4115 let task_prompt = "Find all TODO comments in the codebase";
4116 subagent
4117 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4118 .unwrap();
4119 cx.run_until_parked();
4120
4121 let pending = fake_model.pending_completions();
4122 assert_eq!(pending.len(), 1, "should have one pending completion");
4123
4124 let messages = &pending[0].messages;
4125 let user_messages: Vec<_> = messages
4126 .iter()
4127 .filter(|m| m.role == language_model::Role::User)
4128 .collect();
4129 assert_eq!(user_messages.len(), 1, "should have one user message");
4130
4131 let content = &user_messages[0].content[0];
4132 assert!(
4133 content.to_str().unwrap().contains("TODO"),
4134 "task prompt should be in user message"
4135 );
4136}
4137
4138#[gpui::test]
4139async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4140 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4141 let fake_model = model.as_fake();
4142
4143 cx.update(|cx| {
4144 cx.update_flags(true, vec!["subagents".to_string()]);
4145 });
4146
4147 let subagent_context = SubagentContext {
4148 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4149 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4150 depth: 1,
4151 summary_prompt: "Please summarize what you found".to_string(),
4152 context_low_prompt: "Context low, wrap up".to_string(),
4153 };
4154
4155 let project = thread.read_with(cx, |t, _| t.project.clone());
4156 let project_context = cx.new(|_cx| ProjectContext::default());
4157 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4158 let context_server_registry =
4159 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4160
4161 let subagent = cx.new(|cx| {
4162 Thread::new_subagent(
4163 project.clone(),
4164 project_context,
4165 context_server_registry,
4166 Templates::new(),
4167 model.clone(),
4168 subagent_context,
4169 std::collections::BTreeMap::new(),
4170 cx,
4171 )
4172 });
4173
4174 subagent
4175 .update(cx, |thread, cx| {
4176 thread.submit_user_message("Do some work", cx)
4177 })
4178 .unwrap();
4179 cx.run_until_parked();
4180
4181 fake_model.send_last_completion_stream_text_chunk("I did the work");
4182 fake_model
4183 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4184 fake_model.end_last_completion_stream();
4185 cx.run_until_parked();
4186
4187 subagent
4188 .update(cx, |thread, cx| thread.request_final_summary(cx))
4189 .unwrap();
4190 cx.run_until_parked();
4191
4192 let pending = fake_model.pending_completions();
4193 assert!(
4194 !pending.is_empty(),
4195 "should have pending completion for summary"
4196 );
4197
4198 let messages = &pending.last().unwrap().messages;
4199 let user_messages: Vec<_> = messages
4200 .iter()
4201 .filter(|m| m.role == language_model::Role::User)
4202 .collect();
4203
4204 let last_user = user_messages.last().unwrap();
4205 assert!(
4206 last_user.content[0].to_str().unwrap().contains("summarize"),
4207 "summary prompt should be sent"
4208 );
4209}
4210
4211#[gpui::test]
4212async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4213 init_test(cx);
4214
4215 cx.update(|cx| {
4216 cx.update_flags(true, vec!["subagents".to_string()]);
4217 });
4218
4219 let fs = FakeFs::new(cx.executor());
4220 fs.insert_tree(path!("/test"), json!({})).await;
4221 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4222 let project_context = cx.new(|_cx| ProjectContext::default());
4223 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4224 let context_server_registry =
4225 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4226 let model = Arc::new(FakeLanguageModel::default());
4227
4228 let subagent_context = SubagentContext {
4229 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4230 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4231 depth: 1,
4232 summary_prompt: "Summarize".to_string(),
4233 context_low_prompt: "Context low".to_string(),
4234 };
4235
4236 let subagent = cx.new(|cx| {
4237 let mut thread = Thread::new_subagent(
4238 project.clone(),
4239 project_context,
4240 context_server_registry,
4241 Templates::new(),
4242 model.clone(),
4243 subagent_context,
4244 std::collections::BTreeMap::new(),
4245 cx,
4246 );
4247 thread.add_tool(EchoTool);
4248 thread.add_tool(DelayTool);
4249 thread.add_tool(WordListTool);
4250 thread
4251 });
4252
4253 subagent.read_with(cx, |thread, _| {
4254 assert!(thread.has_registered_tool("echo"));
4255 assert!(thread.has_registered_tool("delay"));
4256 assert!(thread.has_registered_tool("word_list"));
4257 });
4258
4259 let allowed: collections::HashSet<gpui::SharedString> =
4260 vec!["echo".into()].into_iter().collect();
4261
4262 subagent.update(cx, |thread, _cx| {
4263 thread.restrict_tools(&allowed);
4264 });
4265
4266 subagent.read_with(cx, |thread, _| {
4267 assert!(
4268 thread.has_registered_tool("echo"),
4269 "echo should still be available"
4270 );
4271 assert!(
4272 !thread.has_registered_tool("delay"),
4273 "delay should be removed"
4274 );
4275 assert!(
4276 !thread.has_registered_tool("word_list"),
4277 "word_list should be removed"
4278 );
4279 });
4280}
4281
4282#[gpui::test]
4283async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4284 init_test(cx);
4285
4286 cx.update(|cx| {
4287 cx.update_flags(true, vec!["subagents".to_string()]);
4288 });
4289
4290 let fs = FakeFs::new(cx.executor());
4291 fs.insert_tree(path!("/test"), json!({})).await;
4292 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4293 let project_context = cx.new(|_cx| ProjectContext::default());
4294 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4295 let context_server_registry =
4296 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4297 let model = Arc::new(FakeLanguageModel::default());
4298
4299 let parent = cx.new(|cx| {
4300 Thread::new(
4301 project.clone(),
4302 project_context.clone(),
4303 context_server_registry.clone(),
4304 Templates::new(),
4305 Some(model.clone()),
4306 cx,
4307 )
4308 });
4309
4310 let subagent_context = SubagentContext {
4311 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4312 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4313 depth: 1,
4314 summary_prompt: "Summarize".to_string(),
4315 context_low_prompt: "Context low".to_string(),
4316 };
4317
4318 let subagent = cx.new(|cx| {
4319 Thread::new_subagent(
4320 project.clone(),
4321 project_context.clone(),
4322 context_server_registry.clone(),
4323 Templates::new(),
4324 model.clone(),
4325 subagent_context,
4326 std::collections::BTreeMap::new(),
4327 cx,
4328 )
4329 });
4330
4331 parent.update(cx, |thread, _cx| {
4332 thread.register_running_subagent(subagent.downgrade());
4333 });
4334
4335 subagent
4336 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4337 .unwrap();
4338 cx.run_until_parked();
4339
4340 subagent.read_with(cx, |thread, _| {
4341 assert!(!thread.is_turn_complete(), "subagent should be running");
4342 });
4343
4344 parent.update(cx, |thread, cx| {
4345 thread.cancel(cx).detach();
4346 });
4347
4348 subagent.read_with(cx, |thread, _| {
4349 assert!(
4350 thread.is_turn_complete(),
4351 "subagent should be cancelled when parent cancels"
4352 );
4353 });
4354}
4355
4356#[gpui::test]
4357async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4358 // This test verifies that the subagent tool properly handles user cancellation
4359 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4360 init_test(cx);
4361 always_allow_tools(cx);
4362
4363 cx.update(|cx| {
4364 cx.update_flags(true, vec!["subagents".to_string()]);
4365 });
4366
4367 let fs = FakeFs::new(cx.executor());
4368 fs.insert_tree(path!("/test"), json!({})).await;
4369 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4370 let project_context = cx.new(|_cx| ProjectContext::default());
4371 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4372 let context_server_registry =
4373 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4374 let model = Arc::new(FakeLanguageModel::default());
4375
4376 let parent = cx.new(|cx| {
4377 Thread::new(
4378 project.clone(),
4379 project_context.clone(),
4380 context_server_registry.clone(),
4381 Templates::new(),
4382 Some(model.clone()),
4383 cx,
4384 )
4385 });
4386
4387 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4388 std::collections::BTreeMap::new();
4389
4390 #[allow(clippy::arc_with_non_send_sync)]
4391 let tool = Arc::new(SubagentTool::new(
4392 parent.downgrade(),
4393 project.clone(),
4394 project_context,
4395 context_server_registry,
4396 Templates::new(),
4397 0,
4398 parent_tools,
4399 ));
4400
4401 let (event_stream, _rx, mut cancellation_tx) =
4402 crate::ToolCallEventStream::test_with_cancellation();
4403
4404 // Start the subagent tool
4405 let task = cx.update(|cx| {
4406 tool.run(
4407 SubagentToolInput {
4408 subagents: vec![crate::SubagentConfig {
4409 label: "Long running task".to_string(),
4410 task_prompt: "Do a very long task that takes forever".to_string(),
4411 summary_prompt: "Summarize".to_string(),
4412 context_low_prompt: "Context low".to_string(),
4413 timeout_ms: None,
4414 allowed_tools: None,
4415 }],
4416 },
4417 event_stream.clone(),
4418 cx,
4419 )
4420 });
4421
4422 cx.run_until_parked();
4423
4424 // Signal cancellation via the event stream
4425 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4426
4427 // The task should complete promptly with a cancellation error
4428 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4429 let result = futures::select! {
4430 result = task.fuse() => result,
4431 _ = timeout.fuse() => {
4432 panic!("subagent tool did not respond to cancellation within timeout");
4433 }
4434 };
4435
4436 // Verify we got a cancellation error
4437 let err = result.unwrap_err();
4438 assert!(
4439 err.to_string().contains("cancelled by user"),
4440 "expected cancellation error, got: {}",
4441 err
4442 );
4443}
4444
4445#[gpui::test]
4446async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4448 let fake_model = model.as_fake();
4449
4450 cx.update(|cx| {
4451 cx.update_flags(true, vec!["subagents".to_string()]);
4452 });
4453
4454 let subagent_context = SubagentContext {
4455 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4456 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4457 depth: 1,
4458 summary_prompt: "Summarize".to_string(),
4459 context_low_prompt: "Context low".to_string(),
4460 };
4461
4462 let project = thread.read_with(cx, |t, _| t.project.clone());
4463 let project_context = cx.new(|_cx| ProjectContext::default());
4464 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4465 let context_server_registry =
4466 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4467
4468 let subagent = cx.new(|cx| {
4469 Thread::new_subagent(
4470 project.clone(),
4471 project_context,
4472 context_server_registry,
4473 Templates::new(),
4474 model.clone(),
4475 subagent_context,
4476 std::collections::BTreeMap::new(),
4477 cx,
4478 )
4479 });
4480
4481 subagent
4482 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4483 .unwrap();
4484 cx.run_until_parked();
4485
4486 subagent.read_with(cx, |thread, _| {
4487 assert!(!thread.is_turn_complete(), "turn should be in progress");
4488 });
4489
4490 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4491 provider: LanguageModelProviderName::from("Fake".to_string()),
4492 });
4493 fake_model.end_last_completion_stream();
4494 cx.run_until_parked();
4495
4496 subagent.read_with(cx, |thread, _| {
4497 assert!(
4498 thread.is_turn_complete(),
4499 "turn should be complete after non-retryable error"
4500 );
4501 });
4502}
4503
4504#[gpui::test]
4505async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4506 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4507 let fake_model = model.as_fake();
4508
4509 cx.update(|cx| {
4510 cx.update_flags(true, vec!["subagents".to_string()]);
4511 });
4512
4513 let subagent_context = SubagentContext {
4514 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4515 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4516 depth: 1,
4517 summary_prompt: "Summarize your work".to_string(),
4518 context_low_prompt: "Context low, stop and summarize".to_string(),
4519 };
4520
4521 let project = thread.read_with(cx, |t, _| t.project.clone());
4522 let project_context = cx.new(|_cx| ProjectContext::default());
4523 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4524 let context_server_registry =
4525 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4526
4527 let subagent = cx.new(|cx| {
4528 Thread::new_subagent(
4529 project.clone(),
4530 project_context.clone(),
4531 context_server_registry.clone(),
4532 Templates::new(),
4533 model.clone(),
4534 subagent_context.clone(),
4535 std::collections::BTreeMap::new(),
4536 cx,
4537 )
4538 });
4539
4540 subagent.update(cx, |thread, _| {
4541 thread.add_tool(EchoTool);
4542 });
4543
4544 subagent
4545 .update(cx, |thread, cx| {
4546 thread.submit_user_message("Do some work", cx)
4547 })
4548 .unwrap();
4549 cx.run_until_parked();
4550
4551 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4552 fake_model
4553 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4554 fake_model.end_last_completion_stream();
4555 cx.run_until_parked();
4556
4557 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4558 assert!(
4559 interrupt_result.is_ok(),
4560 "interrupt_for_summary should succeed"
4561 );
4562
4563 cx.run_until_parked();
4564
4565 let pending = fake_model.pending_completions();
4566 assert!(
4567 !pending.is_empty(),
4568 "should have pending completion for interrupted summary"
4569 );
4570
4571 let messages = &pending.last().unwrap().messages;
4572 let user_messages: Vec<_> = messages
4573 .iter()
4574 .filter(|m| m.role == language_model::Role::User)
4575 .collect();
4576
4577 let last_user = user_messages.last().unwrap();
4578 let content_str = last_user.content[0].to_str().unwrap();
4579 assert!(
4580 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4581 "context_low_prompt should be sent when interrupting: got {:?}",
4582 content_str
4583 );
4584}
4585
4586#[gpui::test]
4587async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4588 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4589 let fake_model = model.as_fake();
4590
4591 cx.update(|cx| {
4592 cx.update_flags(true, vec!["subagents".to_string()]);
4593 });
4594
4595 let subagent_context = SubagentContext {
4596 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4597 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4598 depth: 1,
4599 summary_prompt: "Summarize".to_string(),
4600 context_low_prompt: "Context low".to_string(),
4601 };
4602
4603 let project = thread.read_with(cx, |t, _| t.project.clone());
4604 let project_context = cx.new(|_cx| ProjectContext::default());
4605 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4606 let context_server_registry =
4607 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4608
4609 let subagent = cx.new(|cx| {
4610 Thread::new_subagent(
4611 project.clone(),
4612 project_context,
4613 context_server_registry,
4614 Templates::new(),
4615 model.clone(),
4616 subagent_context,
4617 std::collections::BTreeMap::new(),
4618 cx,
4619 )
4620 });
4621
4622 subagent
4623 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4624 .unwrap();
4625 cx.run_until_parked();
4626
4627 let max_tokens = model.max_token_count();
4628 let high_usage = language_model::TokenUsage {
4629 input_tokens: (max_tokens as f64 * 0.80) as u64,
4630 output_tokens: 0,
4631 cache_creation_input_tokens: 0,
4632 cache_read_input_tokens: 0,
4633 };
4634
4635 fake_model
4636 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4637 fake_model.send_last_completion_stream_text_chunk("Working...");
4638 fake_model
4639 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4640 fake_model.end_last_completion_stream();
4641 cx.run_until_parked();
4642
4643 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4644 assert!(usage.is_some(), "should have token usage after completion");
4645
4646 let usage = usage.unwrap();
4647 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4648 assert!(
4649 remaining_ratio <= 0.25,
4650 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4651 remaining_ratio * 100.0
4652 );
4653}
4654
4655#[gpui::test]
4656async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4657 init_test(cx);
4658
4659 cx.update(|cx| {
4660 cx.update_flags(true, vec!["subagents".to_string()]);
4661 });
4662
4663 let fs = FakeFs::new(cx.executor());
4664 fs.insert_tree(path!("/test"), json!({})).await;
4665 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4666 let project_context = cx.new(|_cx| ProjectContext::default());
4667 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4668 let context_server_registry =
4669 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4670 let model = Arc::new(FakeLanguageModel::default());
4671
4672 let parent = cx.new(|cx| {
4673 let mut thread = Thread::new(
4674 project.clone(),
4675 project_context.clone(),
4676 context_server_registry.clone(),
4677 Templates::new(),
4678 Some(model.clone()),
4679 cx,
4680 );
4681 thread.add_tool(EchoTool);
4682 thread
4683 });
4684
4685 let mut parent_tools: std::collections::BTreeMap<
4686 gpui::SharedString,
4687 Arc<dyn crate::AnyAgentTool>,
4688 > = std::collections::BTreeMap::new();
4689 parent_tools.insert("echo".into(), EchoTool.erase());
4690
4691 #[allow(clippy::arc_with_non_send_sync)]
4692 let tool = Arc::new(SubagentTool::new(
4693 parent.downgrade(),
4694 project,
4695 project_context,
4696 context_server_registry,
4697 Templates::new(),
4698 0,
4699 parent_tools,
4700 ));
4701
4702 let subagent_configs = vec![crate::SubagentConfig {
4703 label: "Test".to_string(),
4704 task_prompt: "Do something".to_string(),
4705 summary_prompt: "Summarize".to_string(),
4706 context_low_prompt: "Context low".to_string(),
4707 timeout_ms: None,
4708 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4709 }];
4710 let result = tool.validate_subagents(&subagent_configs);
4711 assert!(result.is_err(), "should reject unknown tool");
4712 let err_msg = result.unwrap_err().to_string();
4713 assert!(
4714 err_msg.contains("nonexistent_tool"),
4715 "error should mention the invalid tool name: {}",
4716 err_msg
4717 );
4718 assert!(
4719 err_msg.contains("do not exist"),
4720 "error should explain the tool does not exist: {}",
4721 err_msg
4722 );
4723}
4724
4725#[gpui::test]
4726async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4727 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4728 let fake_model = model.as_fake();
4729
4730 cx.update(|cx| {
4731 cx.update_flags(true, vec!["subagents".to_string()]);
4732 });
4733
4734 let subagent_context = SubagentContext {
4735 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4736 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4737 depth: 1,
4738 summary_prompt: "Summarize".to_string(),
4739 context_low_prompt: "Context low".to_string(),
4740 };
4741
4742 let project = thread.read_with(cx, |t, _| t.project.clone());
4743 let project_context = cx.new(|_cx| ProjectContext::default());
4744 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4745 let context_server_registry =
4746 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4747
4748 let subagent = cx.new(|cx| {
4749 Thread::new_subagent(
4750 project.clone(),
4751 project_context,
4752 context_server_registry,
4753 Templates::new(),
4754 model.clone(),
4755 subagent_context,
4756 std::collections::BTreeMap::new(),
4757 cx,
4758 )
4759 });
4760
4761 subagent
4762 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4763 .unwrap();
4764 cx.run_until_parked();
4765
4766 fake_model
4767 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4768 fake_model.end_last_completion_stream();
4769 cx.run_until_parked();
4770
4771 subagent.read_with(cx, |thread, _| {
4772 assert!(
4773 thread.is_turn_complete(),
4774 "turn should complete even with empty response"
4775 );
4776 });
4777}
4778
4779#[gpui::test]
4780async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4781 init_test(cx);
4782
4783 cx.update(|cx| {
4784 cx.update_flags(true, vec!["subagents".to_string()]);
4785 });
4786
4787 let fs = FakeFs::new(cx.executor());
4788 fs.insert_tree(path!("/test"), json!({})).await;
4789 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4790 let project_context = cx.new(|_cx| ProjectContext::default());
4791 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4792 let context_server_registry =
4793 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4794 let model = Arc::new(FakeLanguageModel::default());
4795
4796 let depth_1_context = SubagentContext {
4797 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4798 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4799 depth: 1,
4800 summary_prompt: "Summarize".to_string(),
4801 context_low_prompt: "Context low".to_string(),
4802 };
4803
4804 let depth_1_subagent = cx.new(|cx| {
4805 Thread::new_subagent(
4806 project.clone(),
4807 project_context.clone(),
4808 context_server_registry.clone(),
4809 Templates::new(),
4810 model.clone(),
4811 depth_1_context,
4812 std::collections::BTreeMap::new(),
4813 cx,
4814 )
4815 });
4816
4817 depth_1_subagent.read_with(cx, |thread, _| {
4818 assert_eq!(thread.depth(), 1);
4819 assert!(thread.is_subagent());
4820 });
4821
4822 let depth_2_context = SubagentContext {
4823 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4824 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4825 depth: 2,
4826 summary_prompt: "Summarize depth 2".to_string(),
4827 context_low_prompt: "Context low depth 2".to_string(),
4828 };
4829
4830 let depth_2_subagent = cx.new(|cx| {
4831 Thread::new_subagent(
4832 project.clone(),
4833 project_context.clone(),
4834 context_server_registry.clone(),
4835 Templates::new(),
4836 model.clone(),
4837 depth_2_context,
4838 std::collections::BTreeMap::new(),
4839 cx,
4840 )
4841 });
4842
4843 depth_2_subagent.read_with(cx, |thread, _| {
4844 assert_eq!(thread.depth(), 2);
4845 assert!(thread.is_subagent());
4846 });
4847
4848 depth_2_subagent
4849 .update(cx, |thread, cx| {
4850 thread.submit_user_message("Nested task", cx)
4851 })
4852 .unwrap();
4853 cx.run_until_parked();
4854
4855 let pending = model.as_fake().pending_completions();
4856 assert!(
4857 !pending.is_empty(),
4858 "depth-2 subagent should be able to submit messages"
4859 );
4860}
4861
4862#[gpui::test]
4863async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4864 init_test(cx);
4865 always_allow_tools(cx);
4866
4867 cx.update(|cx| {
4868 cx.update_flags(true, vec!["subagents".to_string()]);
4869 });
4870
4871 let fs = FakeFs::new(cx.executor());
4872 fs.insert_tree(path!("/test"), json!({})).await;
4873 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4874 let project_context = cx.new(|_cx| ProjectContext::default());
4875 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4876 let context_server_registry =
4877 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4878 let model = Arc::new(FakeLanguageModel::default());
4879 let fake_model = model.as_fake();
4880
4881 let subagent_context = SubagentContext {
4882 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4883 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4884 depth: 1,
4885 summary_prompt: "Summarize what you did".to_string(),
4886 context_low_prompt: "Context low".to_string(),
4887 };
4888
4889 let subagent = cx.new(|cx| {
4890 let mut thread = Thread::new_subagent(
4891 project.clone(),
4892 project_context,
4893 context_server_registry,
4894 Templates::new(),
4895 model.clone(),
4896 subagent_context,
4897 std::collections::BTreeMap::new(),
4898 cx,
4899 );
4900 thread.add_tool(EchoTool);
4901 thread
4902 });
4903
4904 subagent.read_with(cx, |thread, _| {
4905 assert!(
4906 thread.has_registered_tool("echo"),
4907 "subagent should have echo tool"
4908 );
4909 });
4910
4911 subagent
4912 .update(cx, |thread, cx| {
4913 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4914 })
4915 .unwrap();
4916 cx.run_until_parked();
4917
4918 let tool_use = LanguageModelToolUse {
4919 id: "tool_call_1".into(),
4920 name: EchoTool::name().into(),
4921 raw_input: json!({"text": "hello world"}).to_string(),
4922 input: json!({"text": "hello world"}),
4923 is_input_complete: true,
4924 thought_signature: None,
4925 };
4926 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4927 fake_model.end_last_completion_stream();
4928 cx.run_until_parked();
4929
4930 let pending = fake_model.pending_completions();
4931 assert!(
4932 !pending.is_empty(),
4933 "should have pending completion after tool use"
4934 );
4935
4936 let last_completion = pending.last().unwrap();
4937 let has_tool_result = last_completion.messages.iter().any(|m| {
4938 m.content
4939 .iter()
4940 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4941 });
4942 assert!(
4943 has_tool_result,
4944 "tool result should be in the messages sent back to the model"
4945 );
4946}
4947
4948#[gpui::test]
4949async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4950 init_test(cx);
4951
4952 cx.update(|cx| {
4953 cx.update_flags(true, vec!["subagents".to_string()]);
4954 });
4955
4956 let fs = FakeFs::new(cx.executor());
4957 fs.insert_tree(path!("/test"), json!({})).await;
4958 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4959 let project_context = cx.new(|_cx| ProjectContext::default());
4960 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4961 let context_server_registry =
4962 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4963 let model = Arc::new(FakeLanguageModel::default());
4964
4965 let parent = cx.new(|cx| {
4966 Thread::new(
4967 project.clone(),
4968 project_context.clone(),
4969 context_server_registry.clone(),
4970 Templates::new(),
4971 Some(model.clone()),
4972 cx,
4973 )
4974 });
4975
4976 let mut subagents = Vec::new();
4977 for i in 0..MAX_PARALLEL_SUBAGENTS {
4978 let subagent_context = SubagentContext {
4979 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4980 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4981 depth: 1,
4982 summary_prompt: "Summarize".to_string(),
4983 context_low_prompt: "Context low".to_string(),
4984 };
4985
4986 let subagent = cx.new(|cx| {
4987 Thread::new_subagent(
4988 project.clone(),
4989 project_context.clone(),
4990 context_server_registry.clone(),
4991 Templates::new(),
4992 model.clone(),
4993 subagent_context,
4994 std::collections::BTreeMap::new(),
4995 cx,
4996 )
4997 });
4998
4999 parent.update(cx, |thread, _cx| {
5000 thread.register_running_subagent(subagent.downgrade());
5001 });
5002 subagents.push(subagent);
5003 }
5004
5005 parent.read_with(cx, |thread, _| {
5006 assert_eq!(
5007 thread.running_subagent_count(),
5008 MAX_PARALLEL_SUBAGENTS,
5009 "should have MAX_PARALLEL_SUBAGENTS registered"
5010 );
5011 });
5012
5013 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
5014 std::collections::BTreeMap::new();
5015
5016 #[allow(clippy::arc_with_non_send_sync)]
5017 let tool = Arc::new(SubagentTool::new(
5018 parent.downgrade(),
5019 project.clone(),
5020 project_context,
5021 context_server_registry,
5022 Templates::new(),
5023 0,
5024 parent_tools,
5025 ));
5026
5027 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5028
5029 let result = cx.update(|cx| {
5030 tool.run(
5031 SubagentToolInput {
5032 subagents: vec![crate::SubagentConfig {
5033 label: "Test".to_string(),
5034 task_prompt: "Do something".to_string(),
5035 summary_prompt: "Summarize".to_string(),
5036 context_low_prompt: "Context low".to_string(),
5037 timeout_ms: None,
5038 allowed_tools: None,
5039 }],
5040 },
5041 event_stream,
5042 cx,
5043 )
5044 });
5045
5046 let err = result.await.unwrap_err();
5047 assert!(
5048 err.to_string().contains("Maximum parallel subagents"),
5049 "should reject when max parallel subagents reached: {}",
5050 err
5051 );
5052
5053 drop(subagents);
5054}
5055
5056#[gpui::test]
5057async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
5058 init_test(cx);
5059 always_allow_tools(cx);
5060
5061 cx.update(|cx| {
5062 cx.update_flags(true, vec!["subagents".to_string()]);
5063 });
5064
5065 let fs = FakeFs::new(cx.executor());
5066 fs.insert_tree(path!("/test"), json!({})).await;
5067 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5068 let project_context = cx.new(|_cx| ProjectContext::default());
5069 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5070 let context_server_registry =
5071 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5072 let model = Arc::new(FakeLanguageModel::default());
5073 let fake_model = model.as_fake();
5074
5075 let parent = cx.new(|cx| {
5076 let mut thread = Thread::new(
5077 project.clone(),
5078 project_context.clone(),
5079 context_server_registry.clone(),
5080 Templates::new(),
5081 Some(model.clone()),
5082 cx,
5083 );
5084 thread.add_tool(EchoTool);
5085 thread
5086 });
5087
5088 let mut parent_tools: std::collections::BTreeMap<
5089 gpui::SharedString,
5090 Arc<dyn crate::AnyAgentTool>,
5091 > = std::collections::BTreeMap::new();
5092 parent_tools.insert("echo".into(), EchoTool.erase());
5093
5094 #[allow(clippy::arc_with_non_send_sync)]
5095 let tool = Arc::new(SubagentTool::new(
5096 parent.downgrade(),
5097 project.clone(),
5098 project_context,
5099 context_server_registry,
5100 Templates::new(),
5101 0,
5102 parent_tools,
5103 ));
5104
5105 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5106
5107 let task = cx.update(|cx| {
5108 tool.run(
5109 SubagentToolInput {
5110 subagents: vec![crate::SubagentConfig {
5111 label: "Research task".to_string(),
5112 task_prompt: "Find all TODOs in the codebase".to_string(),
5113 summary_prompt: "Summarize what you found".to_string(),
5114 context_low_prompt: "Context low, wrap up".to_string(),
5115 timeout_ms: None,
5116 allowed_tools: None,
5117 }],
5118 },
5119 event_stream,
5120 cx,
5121 )
5122 });
5123
5124 cx.run_until_parked();
5125
5126 let pending = fake_model.pending_completions();
5127 assert!(
5128 !pending.is_empty(),
5129 "subagent should have started and sent a completion request"
5130 );
5131
5132 let first_completion = &pending[0];
5133 let has_task_prompt = first_completion.messages.iter().any(|m| {
5134 m.role == language_model::Role::User
5135 && m.content
5136 .iter()
5137 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5138 });
5139 assert!(has_task_prompt, "task prompt should be sent to subagent");
5140
5141 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5142 fake_model
5143 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5144 fake_model.end_last_completion_stream();
5145 cx.run_until_parked();
5146
5147 let pending = fake_model.pending_completions();
5148 assert!(
5149 !pending.is_empty(),
5150 "should have pending completion for summary request"
5151 );
5152
5153 let last_completion = pending.last().unwrap();
5154 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5155 m.role == language_model::Role::User
5156 && m.content.iter().any(|c| {
5157 c.to_str()
5158 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5159 .unwrap_or(false)
5160 })
5161 });
5162 assert!(
5163 has_summary_prompt,
5164 "summary prompt should be sent after task completion"
5165 );
5166
5167 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5168 fake_model
5169 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5170 fake_model.end_last_completion_stream();
5171 cx.run_until_parked();
5172
5173 let result = task.await;
5174 assert!(result.is_ok(), "subagent tool should complete successfully");
5175
5176 let summary = result.unwrap();
5177 assert!(
5178 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5179 "summary should contain subagent's response: {}",
5180 summary
5181 );
5182}
5183
5184#[gpui::test]
5185async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5186 init_test(cx);
5187
5188 let fs = FakeFs::new(cx.executor());
5189 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5190 .await;
5191 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5192
5193 cx.update(|cx| {
5194 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5195 settings.tool_permissions.tools.insert(
5196 "edit_file".into(),
5197 agent_settings::ToolRules {
5198 default_mode: settings::ToolPermissionMode::Allow,
5199 always_allow: vec![],
5200 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5201 always_confirm: vec![],
5202 invalid_patterns: vec![],
5203 },
5204 );
5205 agent_settings::AgentSettings::override_global(settings, cx);
5206 });
5207
5208 let context_server_registry =
5209 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5210 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5211 let templates = crate::Templates::new();
5212 let thread = cx.new(|cx| {
5213 crate::Thread::new(
5214 project.clone(),
5215 cx.new(|_cx| prompt_store::ProjectContext::default()),
5216 context_server_registry,
5217 templates.clone(),
5218 None,
5219 cx,
5220 )
5221 });
5222
5223 #[allow(clippy::arc_with_non_send_sync)]
5224 let tool = Arc::new(crate::EditFileTool::new(
5225 project.clone(),
5226 thread.downgrade(),
5227 language_registry,
5228 templates,
5229 ));
5230 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5231
5232 let task = cx.update(|cx| {
5233 tool.run(
5234 crate::EditFileToolInput {
5235 display_description: "Edit sensitive file".to_string(),
5236 path: "root/sensitive_config.txt".into(),
5237 mode: crate::EditFileMode::Edit,
5238 },
5239 event_stream,
5240 cx,
5241 )
5242 });
5243
5244 let result = task.await;
5245 assert!(result.is_err(), "expected edit to be blocked");
5246 assert!(
5247 result.unwrap_err().to_string().contains("blocked"),
5248 "error should mention the edit was blocked"
5249 );
5250}
5251
5252#[gpui::test]
5253async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5254 init_test(cx);
5255
5256 let fs = FakeFs::new(cx.executor());
5257 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5258 .await;
5259 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5260
5261 cx.update(|cx| {
5262 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5263 settings.tool_permissions.tools.insert(
5264 "delete_path".into(),
5265 agent_settings::ToolRules {
5266 default_mode: settings::ToolPermissionMode::Allow,
5267 always_allow: vec![],
5268 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5269 always_confirm: vec![],
5270 invalid_patterns: vec![],
5271 },
5272 );
5273 agent_settings::AgentSettings::override_global(settings, cx);
5274 });
5275
5276 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5277
5278 #[allow(clippy::arc_with_non_send_sync)]
5279 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5280 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5281
5282 let task = cx.update(|cx| {
5283 tool.run(
5284 crate::DeletePathToolInput {
5285 path: "root/important_data.txt".to_string(),
5286 },
5287 event_stream,
5288 cx,
5289 )
5290 });
5291
5292 let result = task.await;
5293 assert!(result.is_err(), "expected deletion to be blocked");
5294 assert!(
5295 result.unwrap_err().to_string().contains("blocked"),
5296 "error should mention the deletion was blocked"
5297 );
5298}
5299
5300#[gpui::test]
5301async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5302 init_test(cx);
5303
5304 let fs = FakeFs::new(cx.executor());
5305 fs.insert_tree(
5306 "/root",
5307 json!({
5308 "safe.txt": "content",
5309 "protected": {}
5310 }),
5311 )
5312 .await;
5313 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5314
5315 cx.update(|cx| {
5316 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5317 settings.tool_permissions.tools.insert(
5318 "move_path".into(),
5319 agent_settings::ToolRules {
5320 default_mode: settings::ToolPermissionMode::Allow,
5321 always_allow: vec![],
5322 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5323 always_confirm: vec![],
5324 invalid_patterns: vec![],
5325 },
5326 );
5327 agent_settings::AgentSettings::override_global(settings, cx);
5328 });
5329
5330 #[allow(clippy::arc_with_non_send_sync)]
5331 let tool = Arc::new(crate::MovePathTool::new(project));
5332 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5333
5334 let task = cx.update(|cx| {
5335 tool.run(
5336 crate::MovePathToolInput {
5337 source_path: "root/safe.txt".to_string(),
5338 destination_path: "root/protected/safe.txt".to_string(),
5339 },
5340 event_stream,
5341 cx,
5342 )
5343 });
5344
5345 let result = task.await;
5346 assert!(
5347 result.is_err(),
5348 "expected move to be blocked due to destination path"
5349 );
5350 assert!(
5351 result.unwrap_err().to_string().contains("blocked"),
5352 "error should mention the move was blocked"
5353 );
5354}
5355
5356#[gpui::test]
5357async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5358 init_test(cx);
5359
5360 let fs = FakeFs::new(cx.executor());
5361 fs.insert_tree(
5362 "/root",
5363 json!({
5364 "secret.txt": "secret content",
5365 "public": {}
5366 }),
5367 )
5368 .await;
5369 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5370
5371 cx.update(|cx| {
5372 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5373 settings.tool_permissions.tools.insert(
5374 "move_path".into(),
5375 agent_settings::ToolRules {
5376 default_mode: settings::ToolPermissionMode::Allow,
5377 always_allow: vec![],
5378 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5379 always_confirm: vec![],
5380 invalid_patterns: vec![],
5381 },
5382 );
5383 agent_settings::AgentSettings::override_global(settings, cx);
5384 });
5385
5386 #[allow(clippy::arc_with_non_send_sync)]
5387 let tool = Arc::new(crate::MovePathTool::new(project));
5388 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5389
5390 let task = cx.update(|cx| {
5391 tool.run(
5392 crate::MovePathToolInput {
5393 source_path: "root/secret.txt".to_string(),
5394 destination_path: "root/public/not_secret.txt".to_string(),
5395 },
5396 event_stream,
5397 cx,
5398 )
5399 });
5400
5401 let result = task.await;
5402 assert!(
5403 result.is_err(),
5404 "expected move to be blocked due to source path"
5405 );
5406 assert!(
5407 result.unwrap_err().to_string().contains("blocked"),
5408 "error should mention the move was blocked"
5409 );
5410}
5411
5412#[gpui::test]
5413async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5414 init_test(cx);
5415
5416 let fs = FakeFs::new(cx.executor());
5417 fs.insert_tree(
5418 "/root",
5419 json!({
5420 "confidential.txt": "confidential data",
5421 "dest": {}
5422 }),
5423 )
5424 .await;
5425 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5426
5427 cx.update(|cx| {
5428 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5429 settings.tool_permissions.tools.insert(
5430 "copy_path".into(),
5431 agent_settings::ToolRules {
5432 default_mode: settings::ToolPermissionMode::Allow,
5433 always_allow: vec![],
5434 always_deny: vec![
5435 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5436 ],
5437 always_confirm: vec![],
5438 invalid_patterns: vec![],
5439 },
5440 );
5441 agent_settings::AgentSettings::override_global(settings, cx);
5442 });
5443
5444 #[allow(clippy::arc_with_non_send_sync)]
5445 let tool = Arc::new(crate::CopyPathTool::new(project));
5446 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5447
5448 let task = cx.update(|cx| {
5449 tool.run(
5450 crate::CopyPathToolInput {
5451 source_path: "root/confidential.txt".to_string(),
5452 destination_path: "root/dest/copy.txt".to_string(),
5453 },
5454 event_stream,
5455 cx,
5456 )
5457 });
5458
5459 let result = task.await;
5460 assert!(result.is_err(), "expected copy to be blocked");
5461 assert!(
5462 result.unwrap_err().to_string().contains("blocked"),
5463 "error should mention the copy was blocked"
5464 );
5465}
5466
5467#[gpui::test]
5468async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5469 init_test(cx);
5470
5471 let fs = FakeFs::new(cx.executor());
5472 fs.insert_tree(
5473 "/root",
5474 json!({
5475 "normal.txt": "normal content",
5476 "readonly": {
5477 "config.txt": "readonly content"
5478 }
5479 }),
5480 )
5481 .await;
5482 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5483
5484 cx.update(|cx| {
5485 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5486 settings.tool_permissions.tools.insert(
5487 "save_file".into(),
5488 agent_settings::ToolRules {
5489 default_mode: settings::ToolPermissionMode::Allow,
5490 always_allow: vec![],
5491 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5492 always_confirm: vec![],
5493 invalid_patterns: vec![],
5494 },
5495 );
5496 agent_settings::AgentSettings::override_global(settings, cx);
5497 });
5498
5499 #[allow(clippy::arc_with_non_send_sync)]
5500 let tool = Arc::new(crate::SaveFileTool::new(project));
5501 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5502
5503 let task = cx.update(|cx| {
5504 tool.run(
5505 crate::SaveFileToolInput {
5506 paths: vec![
5507 std::path::PathBuf::from("root/normal.txt"),
5508 std::path::PathBuf::from("root/readonly/config.txt"),
5509 ],
5510 },
5511 event_stream,
5512 cx,
5513 )
5514 });
5515
5516 let result = task.await;
5517 assert!(
5518 result.is_err(),
5519 "expected save to be blocked due to denied path"
5520 );
5521 assert!(
5522 result.unwrap_err().to_string().contains("blocked"),
5523 "error should mention the save was blocked"
5524 );
5525}
5526
5527#[gpui::test]
5528async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5529 init_test(cx);
5530
5531 let fs = FakeFs::new(cx.executor());
5532 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5533 .await;
5534 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5535
5536 cx.update(|cx| {
5537 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5538 settings.always_allow_tool_actions = false;
5539 settings.tool_permissions.tools.insert(
5540 "save_file".into(),
5541 agent_settings::ToolRules {
5542 default_mode: settings::ToolPermissionMode::Allow,
5543 always_allow: vec![],
5544 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5545 always_confirm: vec![],
5546 invalid_patterns: vec![],
5547 },
5548 );
5549 agent_settings::AgentSettings::override_global(settings, cx);
5550 });
5551
5552 #[allow(clippy::arc_with_non_send_sync)]
5553 let tool = Arc::new(crate::SaveFileTool::new(project));
5554 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5555
5556 let task = cx.update(|cx| {
5557 tool.run(
5558 crate::SaveFileToolInput {
5559 paths: vec![std::path::PathBuf::from("root/config.secret")],
5560 },
5561 event_stream,
5562 cx,
5563 )
5564 });
5565
5566 let result = task.await;
5567 assert!(result.is_err(), "expected save to be blocked");
5568 assert!(
5569 result.unwrap_err().to_string().contains("blocked"),
5570 "error should mention the save was blocked"
5571 );
5572}
5573
5574#[gpui::test]
5575async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5576 init_test(cx);
5577
5578 cx.update(|cx| {
5579 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5580 settings.tool_permissions.tools.insert(
5581 "web_search".into(),
5582 agent_settings::ToolRules {
5583 default_mode: settings::ToolPermissionMode::Allow,
5584 always_allow: vec![],
5585 always_deny: vec![
5586 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5587 ],
5588 always_confirm: vec![],
5589 invalid_patterns: vec![],
5590 },
5591 );
5592 agent_settings::AgentSettings::override_global(settings, cx);
5593 });
5594
5595 #[allow(clippy::arc_with_non_send_sync)]
5596 let tool = Arc::new(crate::WebSearchTool);
5597 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5598
5599 let input: crate::WebSearchToolInput =
5600 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5601
5602 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5603
5604 let result = task.await;
5605 assert!(result.is_err(), "expected search to be blocked");
5606 assert!(
5607 result.unwrap_err().to_string().contains("blocked"),
5608 "error should mention the search was blocked"
5609 );
5610}
5611
5612#[gpui::test]
5613async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5614 init_test(cx);
5615
5616 let fs = FakeFs::new(cx.executor());
5617 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5618 .await;
5619 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5620
5621 cx.update(|cx| {
5622 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5623 settings.always_allow_tool_actions = false;
5624 settings.tool_permissions.tools.insert(
5625 "edit_file".into(),
5626 agent_settings::ToolRules {
5627 default_mode: settings::ToolPermissionMode::Confirm,
5628 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5629 always_deny: vec![],
5630 always_confirm: vec![],
5631 invalid_patterns: vec![],
5632 },
5633 );
5634 agent_settings::AgentSettings::override_global(settings, cx);
5635 });
5636
5637 let context_server_registry =
5638 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5639 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5640 let templates = crate::Templates::new();
5641 let thread = cx.new(|cx| {
5642 crate::Thread::new(
5643 project.clone(),
5644 cx.new(|_cx| prompt_store::ProjectContext::default()),
5645 context_server_registry,
5646 templates.clone(),
5647 None,
5648 cx,
5649 )
5650 });
5651
5652 #[allow(clippy::arc_with_non_send_sync)]
5653 let tool = Arc::new(crate::EditFileTool::new(
5654 project,
5655 thread.downgrade(),
5656 language_registry,
5657 templates,
5658 ));
5659 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5660
5661 let _task = cx.update(|cx| {
5662 tool.run(
5663 crate::EditFileToolInput {
5664 display_description: "Edit README".to_string(),
5665 path: "root/README.md".into(),
5666 mode: crate::EditFileMode::Edit,
5667 },
5668 event_stream,
5669 cx,
5670 )
5671 });
5672
5673 cx.run_until_parked();
5674
5675 let event = rx.try_next();
5676 assert!(
5677 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5678 "expected no authorization request for allowed .md file"
5679 );
5680}
5681
5682#[gpui::test]
5683async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5684 init_test(cx);
5685
5686 cx.update(|cx| {
5687 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5688 settings.tool_permissions.tools.insert(
5689 "fetch".into(),
5690 agent_settings::ToolRules {
5691 default_mode: settings::ToolPermissionMode::Allow,
5692 always_allow: vec![],
5693 always_deny: vec![
5694 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5695 ],
5696 always_confirm: vec![],
5697 invalid_patterns: vec![],
5698 },
5699 );
5700 agent_settings::AgentSettings::override_global(settings, cx);
5701 });
5702
5703 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5704
5705 #[allow(clippy::arc_with_non_send_sync)]
5706 let tool = Arc::new(crate::FetchTool::new(http_client));
5707 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5708
5709 let input: crate::FetchToolInput =
5710 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5711
5712 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5713
5714 let result = task.await;
5715 assert!(result.is_err(), "expected fetch to be blocked");
5716 assert!(
5717 result.unwrap_err().to_string().contains("blocked"),
5718 "error should mention the fetch was blocked"
5719 );
5720}
5721
5722#[gpui::test]
5723async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5724 init_test(cx);
5725
5726 cx.update(|cx| {
5727 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5728 settings.always_allow_tool_actions = false;
5729 settings.tool_permissions.tools.insert(
5730 "fetch".into(),
5731 agent_settings::ToolRules {
5732 default_mode: settings::ToolPermissionMode::Confirm,
5733 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5734 always_deny: vec![],
5735 always_confirm: vec![],
5736 invalid_patterns: vec![],
5737 },
5738 );
5739 agent_settings::AgentSettings::override_global(settings, cx);
5740 });
5741
5742 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5743
5744 #[allow(clippy::arc_with_non_send_sync)]
5745 let tool = Arc::new(crate::FetchTool::new(http_client));
5746 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5747
5748 let input: crate::FetchToolInput =
5749 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5750
5751 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5752
5753 cx.run_until_parked();
5754
5755 let event = rx.try_next();
5756 assert!(
5757 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5758 "expected no authorization request for allowed docs.rs URL"
5759 );
5760}