1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use feature_flags::FeatureFlagAppExt as _;
11use fs::{FakeFs, Fs};
12use futures::{
13 FutureExt as _, StreamExt,
14 channel::{
15 mpsc::{self, UnboundedReceiver},
16 oneshot,
17 },
18 future::{Fuse, Shared},
19};
20use gpui::{
21 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
22 http_client::FakeHttpClient,
23};
24use indoc::indoc;
25use language_model::{
26 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
27 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
29 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
30};
31use pretty_assertions::assert_eq;
32use project::{
33 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
34};
35use prompt_store::ProjectContext;
36use reqwest_client::ReqwestClient;
37use schemars::JsonSchema;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use settings::{Settings, SettingsStore};
41use std::{
42 path::Path,
43 pin::Pin,
44 rc::Rc,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, Ordering},
48 },
49 time::Duration,
50};
51use util::path;
52
53mod test_tools;
54use test_tools::*;
55
56fn init_test(cx: &mut TestAppContext) {
57 cx.update(|cx| {
58 let settings_store = SettingsStore::test(cx);
59 cx.set_global(settings_store);
60 });
61}
62
/// Fake `crate::TerminalHandle` used to exercise the terminal tool without
/// spawning a real process. Returns canned output and records `kill` calls.
///
/// Field order is kept deliberately: it determines drop order of the exit
/// sender relative to the shared `wait_for_exit` task.
struct FakeTerminalHandle {
    // Set to `true` by `kill`; read back through `was_killed`.
    killed: Arc<AtomicBool>,
    // Value reported by `was_stopped_by_user`; written by `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    // Taken (at most once) by `signal_exit` to fire the exit signal.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Task handed out by `wait_for_exit`; cloned for each caller.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned response returned by `current_output`.
    output: acp::TerminalOutputResponse,
    // Identifier returned by `id`.
    id: acp::TerminalId,
}
71
72impl FakeTerminalHandle {
73 fn new_never_exits(cx: &mut App) -> Self {
74 let killed = Arc::new(AtomicBool::new(false));
75 let stopped_by_user = Arc::new(AtomicBool::new(false));
76
77 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
78
79 let wait_for_exit = cx
80 .spawn(async move |_cx| {
81 // Wait for the exit signal (sent when kill() is called)
82 let _ = exit_receiver.await;
83 acp::TerminalExitStatus::new()
84 })
85 .shared();
86
87 Self {
88 killed,
89 stopped_by_user,
90 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
91 wait_for_exit,
92 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
93 id: acp::TerminalId::new("fake_terminal".to_string()),
94 }
95 }
96
97 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
98 let killed = Arc::new(AtomicBool::new(false));
99 let stopped_by_user = Arc::new(AtomicBool::new(false));
100 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
101
102 let wait_for_exit = cx
103 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
104 .shared();
105
106 Self {
107 killed,
108 stopped_by_user,
109 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
110 wait_for_exit,
111 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
112 id: acp::TerminalId::new("fake_terminal".to_string()),
113 }
114 }
115
116 fn was_killed(&self) -> bool {
117 self.killed.load(Ordering::SeqCst)
118 }
119
120 fn set_stopped_by_user(&self, stopped: bool) {
121 self.stopped_by_user.store(stopped, Ordering::SeqCst);
122 }
123
124 fn signal_exit(&self) {
125 if let Some(sender) = self.exit_sender.borrow_mut().take() {
126 let _ = sender.send(());
127 }
128 }
129}
130
131impl crate::TerminalHandle for FakeTerminalHandle {
132 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
133 Ok(self.id.clone())
134 }
135
136 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
137 Ok(self.output.clone())
138 }
139
140 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
141 Ok(self.wait_for_exit.clone())
142 }
143
144 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
145 self.killed.store(true, Ordering::SeqCst);
146 self.signal_exit();
147 Ok(())
148 }
149
150 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
151 Ok(self.stopped_by_user.load(Ordering::SeqCst))
152 }
153}
154
155struct FakeThreadEnvironment {
156 handle: Rc<FakeTerminalHandle>,
157}
158
159impl crate::ThreadEnvironment for FakeThreadEnvironment {
160 fn create_terminal(
161 &self,
162 _command: String,
163 _cwd: Option<std::path::PathBuf>,
164 _output_byte_limit: Option<u64>,
165 _cx: &mut AsyncApp,
166 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
167 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
168 }
169}
170
171/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
172struct MultiTerminalEnvironment {
173 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
174}
175
176impl MultiTerminalEnvironment {
177 fn new() -> Self {
178 Self {
179 handles: std::cell::RefCell::new(Vec::new()),
180 }
181 }
182
183 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
184 self.handles.borrow().clone()
185 }
186}
187
188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
189 fn create_terminal(
190 &self,
191 _command: String,
192 _cwd: Option<std::path::PathBuf>,
193 _output_byte_limit: Option<u64>,
194 cx: &mut AsyncApp,
195 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
196 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
197 self.handles.borrow_mut().push(handle.clone());
198 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
199 }
200}
201
202fn always_allow_tools(cx: &mut TestAppContext) {
203 cx.update(|cx| {
204 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
205 settings.always_allow_tool_actions = true;
206 agent_settings::AgentSettings::override_global(settings, cx);
207 });
208}
209
// Streams "Hello" back from the fake model for a single user message and checks
// the rendered assistant message plus the single `EndTurn` stop event.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Let the thread issue its completion request before we feed the fake stream.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
239
// Runs the terminal tool against a terminal that never exits, with a 5ms
// timeout, and checks that the handle is killed and the partial output is
// included in the tool result.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // The environment is an `Rc` (not Send/Sync), which trips this lint when the
    // tool is wrapped in `Arc`; the `Arc` never leaves this test.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update must embed the terminal in the tool call content.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // NOTE(review): `deadline` is wall-clock time, while the 1ms waits below go
    // through the executor's timer (presumably simulated in gpui tests) — confirm
    // this bound is meaningful.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        // Poll the tool task without blocking; finish as soon as it completes.
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let pending work run, then wait on a 1ms timer so the tool's timeout can fire.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
305
// Runs the terminal tool with no timeout against a terminal that never exits and
// checks that the handle is still alive after a short wait.
#[gpui::test]
// NOTE(review): ignored without a stated reason — document it via
// `#[ignore = "..."]` once confirmed.
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // The environment is an `Rc` (not Send/Sync), which trips this lint when the
    // tool is wrapped in `Arc`; the `Arc` never leaves this test.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Keep the task alive (bound, not awaited) for the rest of the test.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
355
// Streams a thinking event followed by text and checks that the assistant message
// renders the thinking inside `<think>` tags before the final answer.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
399
// With a tool registered, the system prompt (first request message) must include
// the project context's shell and the "## Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
444
// With no tools registered, the system prompt must omit both the "## Tool Use"
// and "## Fixing Diagnostics" sections.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
480
// Checks that only the newest message in each request is marked `cache: true`.
// The newest message is a user message on ordinary turns and the tool result
// after a tool call. `messages[1..]` skips the system prompt.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the thread issues a follow-up request carrying the result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
626
// End-to-end (real model, `e2e` feature only): exercises a fast tool (echo) and a
// slow tool (delay), and checks that the delay tool's "Ding" output is echoed.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The final agent message must contain text mentioning "Ding".
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
686
// End-to-end (real model, `e2e` feature only): checks that a `word_list` tool call
// is visible in the thread while its input is still streaming, and that complete
// input contains every expected key.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask for a word_list tool call and inspect the thread on every tool-call event.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Still streaming: record that we saw input missing the last key.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
738
// Covers the permission flow for a tool that requires authorization: allow once,
// deny, "always allow" for the tool, and then a call that needs no prompt at all.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two tool calls in one response, each requiring its own authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries both results: success for the first, an error for the second.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
877
// A tool call naming a tool that is not registered still surfaces as a pending
// tool call, and is then updated to `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
907
// After the model reports `ToolUseLimitReached`, the event stream ends with a
// `ToolUseLimitReachedError`. `thread.resume` then re-sends the history plus a
// cached "Continue where you left off" user message.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
1022
// After hitting the tool use limit, sending a new user message (instead of
// resuming) re-sends the history with the new message appended and cached.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The tool use and the limit event arrive in the same response.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1099
1100async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1101 let event = events
1102 .next()
1103 .await
1104 .expect("no tool call authorization event received")
1105 .unwrap();
1106 match event {
1107 ThreadEvent::ToolCall(tool_call) => tool_call,
1108 event => {
1109 panic!("Unexpected event {event:?}");
1110 }
1111 }
1112}
1113
1114async fn expect_tool_call_update_fields(
1115 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1116) -> acp::ToolCallUpdate {
1117 let event = events
1118 .next()
1119 .await
1120 .expect("no tool call authorization event received")
1121 .unwrap();
1122 match event {
1123 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1124 event => {
1125 panic!("Unexpected event {event:?}");
1126 }
1127 }
1128}
1129
1130async fn next_tool_call_authorization(
1131 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1132) -> ToolCallAuthorization {
1133 loop {
1134 let event = events
1135 .next()
1136 .await
1137 .expect("no tool call authorization event received")
1138 .unwrap();
1139 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1140 let permission_kinds = tool_call_authorization
1141 .options
1142 .iter()
1143 .map(|o| o.kind)
1144 .collect::<Vec<_>>();
1145 // Only 2 options now: AllowAlways (for tool) and AllowOnce (granularity only)
1146 // Deny is handled by the UI buttons, not as a separate option
1147 assert_eq!(
1148 permission_kinds,
1149 vec![
1150 acp::PermissionOptionKind::AllowAlways,
1151 acp::PermissionOptionKind::AllowOnce,
1152 ]
1153 );
1154 return tool_call_authorization;
1155 }
1156 }
1157}
1158
1159#[gpui::test]
1160#[cfg_attr(not(feature = "e2e"), ignore)]
1161async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1162 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1163
1164 // Test concurrent tool calls with different delay times
1165 let events = thread
1166 .update(cx, |thread, cx| {
1167 thread.add_tool(DelayTool);
1168 thread.send(
1169 UserMessageId::new(),
1170 [
1171 "Call the delay tool twice in the same message.",
1172 "Once with 100ms. Once with 300ms.",
1173 "When both timers are complete, describe the outputs.",
1174 ],
1175 cx,
1176 )
1177 })
1178 .unwrap()
1179 .collect()
1180 .await;
1181
1182 let stop_reasons = stop_events(events);
1183 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1184
1185 thread.update(cx, |thread, _cx| {
1186 let last_message = thread.last_message().unwrap();
1187 let agent_message = last_message.as_agent_message().unwrap();
1188 let text = agent_message
1189 .content
1190 .iter()
1191 .filter_map(|content| {
1192 if let AgentMessageContent::Text(text) = content {
1193 Some(text.as_str())
1194 } else {
1195 None
1196 }
1197 })
1198 .collect::<String>();
1199
1200 assert!(text.contains("Ding"));
1201 });
1202}
1203
/// Verifies that the active agent profile controls which tools are offered to the model.
///
/// Two profiles are written to the settings file: `test-1` enables the echo and delay tools,
/// and `test-2` enables only the infinite tool. Each send must expose exactly the tools of the
/// profile selected just before it.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the profile decides which are actually offered.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Select the test-1 profile explicitly and verify it offers the delay and echo tools.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    // End the first turn so the next send starts from an idle thread.
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1283
1284#[gpui::test]
1285async fn test_mcp_tools(cx: &mut TestAppContext) {
1286 let ThreadTest {
1287 model,
1288 thread,
1289 context_server_store,
1290 fs,
1291 ..
1292 } = setup(cx, TestModel::Fake).await;
1293 let fake_model = model.as_fake();
1294
1295 // Override profiles and wait for settings to be loaded.
1296 fs.insert_file(
1297 paths::settings_file(),
1298 json!({
1299 "agent": {
1300 "always_allow_tool_actions": true,
1301 "profiles": {
1302 "test": {
1303 "name": "Test Profile",
1304 "enable_all_context_servers": true,
1305 "tools": {
1306 EchoTool::name(): true,
1307 }
1308 },
1309 }
1310 }
1311 })
1312 .to_string()
1313 .into_bytes(),
1314 )
1315 .await;
1316 cx.run_until_parked();
1317 thread.update(cx, |thread, cx| {
1318 thread.set_profile(AgentProfileId("test".into()), cx)
1319 });
1320
1321 let mut mcp_tool_calls = setup_context_server(
1322 "test_server",
1323 vec![context_server::types::Tool {
1324 name: "echo".into(),
1325 description: None,
1326 input_schema: serde_json::to_value(EchoTool::input_schema(
1327 LanguageModelToolSchemaFormat::JsonSchema,
1328 ))
1329 .unwrap(),
1330 output_schema: None,
1331 annotations: None,
1332 }],
1333 &context_server_store,
1334 cx,
1335 );
1336
1337 let events = thread.update(cx, |thread, cx| {
1338 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1339 });
1340 cx.run_until_parked();
1341
1342 // Simulate the model calling the MCP tool.
1343 let completion = fake_model.pending_completions().pop().unwrap();
1344 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1345 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1346 LanguageModelToolUse {
1347 id: "tool_1".into(),
1348 name: "echo".into(),
1349 raw_input: json!({"text": "test"}).to_string(),
1350 input: json!({"text": "test"}),
1351 is_input_complete: true,
1352 thought_signature: None,
1353 },
1354 ));
1355 fake_model.end_last_completion_stream();
1356 cx.run_until_parked();
1357
1358 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1359 assert_eq!(tool_call_params.name, "echo");
1360 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1361 tool_call_response
1362 .send(context_server::types::CallToolResponse {
1363 content: vec![context_server::types::ToolResponseContent::Text {
1364 text: "test".into(),
1365 }],
1366 is_error: None,
1367 meta: None,
1368 structured_content: None,
1369 })
1370 .unwrap();
1371 cx.run_until_parked();
1372
1373 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1374 fake_model.send_last_completion_stream_text_chunk("Done!");
1375 fake_model.end_last_completion_stream();
1376 events.collect::<Vec<_>>().await;
1377
1378 // Send again after adding the echo tool, ensuring the name collision is resolved.
1379 let events = thread.update(cx, |thread, cx| {
1380 thread.add_tool(EchoTool);
1381 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1382 });
1383 cx.run_until_parked();
1384 let completion = fake_model.pending_completions().pop().unwrap();
1385 assert_eq!(
1386 tool_names_for_completion(&completion),
1387 vec!["echo", "test_server_echo"]
1388 );
1389 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1390 LanguageModelToolUse {
1391 id: "tool_2".into(),
1392 name: "test_server_echo".into(),
1393 raw_input: json!({"text": "mcp"}).to_string(),
1394 input: json!({"text": "mcp"}),
1395 is_input_complete: true,
1396 thought_signature: None,
1397 },
1398 ));
1399 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1400 LanguageModelToolUse {
1401 id: "tool_3".into(),
1402 name: "echo".into(),
1403 raw_input: json!({"text": "native"}).to_string(),
1404 input: json!({"text": "native"}),
1405 is_input_complete: true,
1406 thought_signature: None,
1407 },
1408 ));
1409 fake_model.end_last_completion_stream();
1410 cx.run_until_parked();
1411
1412 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1413 assert_eq!(tool_call_params.name, "echo");
1414 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1415 tool_call_response
1416 .send(context_server::types::CallToolResponse {
1417 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1418 is_error: None,
1419 meta: None,
1420 structured_content: None,
1421 })
1422 .unwrap();
1423 cx.run_until_parked();
1424
1425 // Ensure the tool results were inserted with the correct names.
1426 let completion = fake_model.pending_completions().pop().unwrap();
1427 assert_eq!(
1428 completion.messages.last().unwrap().content,
1429 vec![
1430 MessageContent::ToolResult(LanguageModelToolResult {
1431 tool_use_id: "tool_3".into(),
1432 tool_name: "echo".into(),
1433 is_error: false,
1434 content: "native".into(),
1435 output: Some("native".into()),
1436 },),
1437 MessageContent::ToolResult(LanguageModelToolResult {
1438 tool_use_id: "tool_2".into(),
1439 tool_name: "test_server_echo".into(),
1440 is_error: false,
1441 content: "mcp".into(),
1442 output: Some("mcp".into()),
1443 },),
1444 ]
1445 );
1446 fake_model.end_last_completion_stream();
1447 events.collect::<Vec<_>>().await;
1448}
1449
/// Verifies how tool names are exposed to the model when MCP tools from several context
/// servers collide with native tools and with each other, including names at or beyond
/// `MAX_TOOL_NAME_LENGTH`.
///
/// The expected list at the end shows:
/// - native tools keep their plain names (e.g. `echo`),
/// - non-conflicting MCP tools keep their plain names (`unique_tool_1`, `unique_tool_2`),
/// - MCP `echo` tools that conflict are prefixed with their server name (`xxx_echo`, `yyy_echo`),
/// - the conflicting `"a" * (MAX_TOOL_NAME_LENGTH - 2)` tools get a one-letter server prefix
///   (`y_…`, `z_…`),
/// - only one `"b" * (MAX_TOOL_NAME_LENGTH - 1)` entry and one `"c"` entry appear — presumably
///   because no prefix fits within the length limit; confirm against the naming logic.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" also defines long names that collide with identical names on server "zzz".
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz" repeats the long names and adds one that exceeds MAX_TOOL_NAME_LENGTH.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1615
/// End-to-end check (real model) that cancelling a turn while a never-finishing tool is
/// running closes the event stream with `StopReason::Cancelled`, and that the thread can still
/// complete a new turn afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls must arrive in this title order; the echo call must also report `Completed`.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1701
/// Verifies that cancelling a turn while the terminal tool is running kills the terminal,
/// ends the turn with `StopReason::Cancelled`, and records a tool result that contains the
/// terminal's partial output plus a "user stopped" note, rather than only a generic
/// cancellation message.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal that never exits on its own, so only cancellation can stop it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1798
/// Verifies that a tool observing `event_stream.cancelled_by_user()` lets a thread cancel
/// finish promptly (within a 5s timer), that the tool saw the cancellation, and that the turn
/// ends with `StopReason::Cancelled` and the thread remains usable.
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a shared flag the tool sets when it observes cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported
    // Poll in bounded rounds: park the executor, drain buffered events, then wait 10ms.
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1893
1894/// Helper to verify thread can recover after cancellation by sending a simple message.
1895async fn verify_thread_recovery(
1896 thread: &Entity<Thread>,
1897 fake_model: &FakeLanguageModel,
1898 cx: &mut TestAppContext,
1899) {
1900 let events = thread
1901 .update(cx, |thread, cx| {
1902 thread.send(
1903 UserMessageId::new(),
1904 ["Testing: reply with 'Hello' then stop."],
1905 cx,
1906 )
1907 })
1908 .unwrap();
1909 cx.run_until_parked();
1910 fake_model.send_last_completion_stream_text_chunk("Hello");
1911 fake_model
1912 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1913 fake_model.end_last_completion_stream();
1914
1915 let events = events.collect::<Vec<_>>().await;
1916 thread.update(cx, |thread, _cx| {
1917 let message = thread.last_message().unwrap();
1918 let agent_message = message.as_agent_message().unwrap();
1919 assert_eq!(
1920 agent_message.content,
1921 vec![AgentMessageContent::Text("Hello".to_string())]
1922 );
1923 });
1924 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1925}
1926
1927/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1928async fn wait_for_terminal_tool_started(
1929 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1930 cx: &mut TestAppContext,
1931) {
1932 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1933 for _ in 0..deadline {
1934 cx.run_until_parked();
1935
1936 while let Some(Some(event)) = events.next().now_or_never() {
1937 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1938 update,
1939 ))) = &event
1940 {
1941 if update.fields.content.as_ref().is_some_and(|content| {
1942 content
1943 .iter()
1944 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1945 }) {
1946 return;
1947 }
1948 }
1949 }
1950
1951 cx.background_executor
1952 .timer(Duration::from_millis(10))
1953 .await;
1954 }
1955 panic!("terminal tool did not start within the expected time");
1956}
1957
1958/// Collects events until a Stop event is received, driving the executor to completion.
1959async fn collect_events_until_stop(
1960 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1961 cx: &mut TestAppContext,
1962) -> Vec<Result<ThreadEvent>> {
1963 let mut collected = Vec::new();
1964 let deadline = cx.executor().num_cpus() * 200;
1965
1966 for _ in 0..deadline {
1967 cx.executor().advance_clock(Duration::from_millis(10));
1968 cx.run_until_parked();
1969
1970 while let Some(Some(event)) = events.next().now_or_never() {
1971 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1972 collected.push(event);
1973 if is_stop {
1974 return collected;
1975 }
1976 }
1977 }
1978 panic!(
1979 "did not receive Stop event within the expected time; collected {} events",
1980 collected.len()
1981 );
1982}
1983
/// Verifies that truncating the thread at the message whose turn is running a terminal tool
/// kills the terminal, leaves the thread empty, and keeps the thread usable.
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal that never exits on its own, so only truncation can stop it.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the message id so the thread can be truncated at this message later.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2050
/// Verifies that a single thread cancel kills every running terminal when two terminal tool
/// calls run concurrently, and ends the turn with `StopReason::Cancelled`.
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // The environment records each terminal handle it creates so they can be inspected later.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates
    // Poll in bounded rounds: park the executor, drain buffered events, then wait 10ms.
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2159
2160#[gpui::test]
2161async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2162 // Tests that clicking the stop button on the terminal card (as opposed to the main
2163 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2164 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2165 always_allow_tools(cx);
2166 let fake_model = model.as_fake();
2167
2168 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2169 let environment = Rc::new(FakeThreadEnvironment {
2170 handle: handle.clone(),
2171 });
2172
2173 let mut events = thread
2174 .update(cx, |thread, cx| {
2175 thread.add_tool(crate::TerminalTool::new(
2176 thread.project().clone(),
2177 environment,
2178 ));
2179 thread.send(UserMessageId::new(), ["run a command"], cx)
2180 })
2181 .unwrap();
2182
2183 cx.run_until_parked();
2184
2185 // Simulate the model calling the terminal tool
2186 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2187 LanguageModelToolUse {
2188 id: "terminal_tool_1".into(),
2189 name: "terminal".into(),
2190 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2191 input: json!({"command": "sleep 1000", "cd": "."}),
2192 is_input_complete: true,
2193 thought_signature: None,
2194 },
2195 ));
2196 fake_model.end_last_completion_stream();
2197
2198 // Wait for the terminal tool to start running
2199 wait_for_terminal_tool_started(&mut events, cx).await;
2200
2201 // Simulate user clicking stop on the terminal card itself.
2202 // This sets the flag and signals exit (simulating what the real UI would do).
2203 handle.set_stopped_by_user(true);
2204 handle.killed.store(true, Ordering::SeqCst);
2205 handle.signal_exit();
2206
2207 // Wait for the tool to complete
2208 cx.run_until_parked();
2209
2210 // The thread continues after tool completion - simulate the model ending its turn
2211 fake_model
2212 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2213 fake_model.end_last_completion_stream();
2214
2215 // Collect remaining events
2216 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2217
2218 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2219 assert_eq!(
2220 stop_events(remaining_events),
2221 vec![acp::StopReason::EndTurn],
2222 );
2223
2224 // Verify the tool result indicates user stopped
2225 thread.update(cx, |thread, _cx| {
2226 let message = thread.last_message().unwrap();
2227 let agent_message = message.as_agent_message().unwrap();
2228
2229 let tool_use = agent_message
2230 .content
2231 .iter()
2232 .find_map(|content| match content {
2233 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2234 _ => None,
2235 })
2236 .expect("expected tool use in agent message");
2237
2238 let tool_result = agent_message
2239 .tool_results
2240 .get(&tool_use.id)
2241 .expect("expected tool result");
2242
2243 let result_text = match &tool_result.content {
2244 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2245 _ => panic!("expected text content in tool result"),
2246 };
2247
2248 assert!(
2249 result_text.contains("The user stopped this command"),
2250 "expected tool result to indicate user stopped, got: {result_text}"
2251 );
2252 });
2253}
2254
2255#[gpui::test]
2256async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2257 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2258 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2259 always_allow_tools(cx);
2260 let fake_model = model.as_fake();
2261
2262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2263 let environment = Rc::new(FakeThreadEnvironment {
2264 handle: handle.clone(),
2265 });
2266
2267 let mut events = thread
2268 .update(cx, |thread, cx| {
2269 thread.add_tool(crate::TerminalTool::new(
2270 thread.project().clone(),
2271 environment,
2272 ));
2273 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2274 })
2275 .unwrap();
2276
2277 cx.run_until_parked();
2278
2279 // Simulate the model calling the terminal tool with a short timeout
2280 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2281 LanguageModelToolUse {
2282 id: "terminal_tool_1".into(),
2283 name: "terminal".into(),
2284 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2285 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2286 is_input_complete: true,
2287 thought_signature: None,
2288 },
2289 ));
2290 fake_model.end_last_completion_stream();
2291
2292 // Wait for the terminal tool to start running
2293 wait_for_terminal_tool_started(&mut events, cx).await;
2294
2295 // Advance clock past the timeout
2296 cx.executor().advance_clock(Duration::from_millis(200));
2297 cx.run_until_parked();
2298
2299 // The thread continues after tool completion - simulate the model ending its turn
2300 fake_model
2301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2302 fake_model.end_last_completion_stream();
2303
2304 // Collect remaining events
2305 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2306
2307 // Verify the terminal was killed due to timeout
2308 assert!(
2309 handle.was_killed(),
2310 "expected terminal handle to be killed on timeout"
2311 );
2312
2313 // Verify we got an EndTurn (the tool completed, just with timeout)
2314 assert_eq!(
2315 stop_events(remaining_events),
2316 vec![acp::StopReason::EndTurn],
2317 );
2318
2319 // Verify the tool result indicates timeout, not user stopped
2320 thread.update(cx, |thread, _cx| {
2321 let message = thread.last_message().unwrap();
2322 let agent_message = message.as_agent_message().unwrap();
2323
2324 let tool_use = agent_message
2325 .content
2326 .iter()
2327 .find_map(|content| match content {
2328 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2329 _ => None,
2330 })
2331 .expect("expected tool use in agent message");
2332
2333 let tool_result = agent_message
2334 .tool_results
2335 .get(&tool_use.id)
2336 .expect("expected tool result");
2337
2338 let result_text = match &tool_result.content {
2339 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2340 _ => panic!("expected text content in tool result"),
2341 };
2342
2343 assert!(
2344 result_text.contains("timed out"),
2345 "expected tool result to indicate timeout, got: {result_text}"
2346 );
2347 assert!(
2348 !result_text.contains("The user stopped"),
2349 "tool result should not mention user stopped when it timed out, got: {result_text}"
2350 );
2351 });
2352}
2353
2354#[gpui::test]
2355async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2356 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2357 let fake_model = model.as_fake();
2358
2359 let events_1 = thread
2360 .update(cx, |thread, cx| {
2361 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2362 })
2363 .unwrap();
2364 cx.run_until_parked();
2365 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2366 cx.run_until_parked();
2367
2368 let events_2 = thread
2369 .update(cx, |thread, cx| {
2370 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2371 })
2372 .unwrap();
2373 cx.run_until_parked();
2374 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2375 fake_model
2376 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2377 fake_model.end_last_completion_stream();
2378
2379 let events_1 = events_1.collect::<Vec<_>>().await;
2380 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2381 let events_2 = events_2.collect::<Vec<_>>().await;
2382 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2383}
2384
2385#[gpui::test]
2386async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2387 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2388 let fake_model = model.as_fake();
2389
2390 let events_1 = thread
2391 .update(cx, |thread, cx| {
2392 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2393 })
2394 .unwrap();
2395 cx.run_until_parked();
2396 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2397 fake_model
2398 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2399 fake_model.end_last_completion_stream();
2400 let events_1 = events_1.collect::<Vec<_>>().await;
2401
2402 let events_2 = thread
2403 .update(cx, |thread, cx| {
2404 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2405 })
2406 .unwrap();
2407 cx.run_until_parked();
2408 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2409 fake_model
2410 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2411 fake_model.end_last_completion_stream();
2412 let events_2 = events_2.collect::<Vec<_>>().await;
2413
2414 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2415 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2416}
2417
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // Verifies that a `StopReason::Refusal` from the model rolls the refused turn
    // back out of the thread and is surfaced as the stop reason of the send.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is present before the model has produced any output.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
                "}
        );
    });

    // Partial assistant output is visible while the completion is still streaming.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
                "}
        );
    });

    // If the model refuses to continue, the thread removes the last user message
    // together with everything after it; here that leaves the thread empty.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2466
2467#[gpui::test]
2468async fn test_truncate_first_message(cx: &mut TestAppContext) {
2469 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2470 let fake_model = model.as_fake();
2471
2472 let message_id = UserMessageId::new();
2473 thread
2474 .update(cx, |thread, cx| {
2475 thread.send(message_id.clone(), ["Hello"], cx)
2476 })
2477 .unwrap();
2478 cx.run_until_parked();
2479 thread.read_with(cx, |thread, _| {
2480 assert_eq!(
2481 thread.to_markdown(),
2482 indoc! {"
2483 ## User
2484
2485 Hello
2486 "}
2487 );
2488 assert_eq!(thread.latest_token_usage(), None);
2489 });
2490
2491 fake_model.send_last_completion_stream_text_chunk("Hey!");
2492 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2493 language_model::TokenUsage {
2494 input_tokens: 32_000,
2495 output_tokens: 16_000,
2496 cache_creation_input_tokens: 0,
2497 cache_read_input_tokens: 0,
2498 },
2499 ));
2500 cx.run_until_parked();
2501 thread.read_with(cx, |thread, _| {
2502 assert_eq!(
2503 thread.to_markdown(),
2504 indoc! {"
2505 ## User
2506
2507 Hello
2508
2509 ## Assistant
2510
2511 Hey!
2512 "}
2513 );
2514 assert_eq!(
2515 thread.latest_token_usage(),
2516 Some(acp_thread::TokenUsage {
2517 used_tokens: 32_000 + 16_000,
2518 max_tokens: 1_000_000,
2519 output_tokens: 16_000,
2520 })
2521 );
2522 });
2523
2524 thread
2525 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2526 .unwrap();
2527 cx.run_until_parked();
2528 thread.read_with(cx, |thread, _| {
2529 assert_eq!(thread.to_markdown(), "");
2530 assert_eq!(thread.latest_token_usage(), None);
2531 });
2532
2533 // Ensure we can still send a new message after truncation.
2534 thread
2535 .update(cx, |thread, cx| {
2536 thread.send(UserMessageId::new(), ["Hi"], cx)
2537 })
2538 .unwrap();
2539 thread.update(cx, |thread, _cx| {
2540 assert_eq!(
2541 thread.to_markdown(),
2542 indoc! {"
2543 ## User
2544
2545 Hi
2546 "}
2547 );
2548 });
2549 cx.run_until_parked();
2550 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2551 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2552 language_model::TokenUsage {
2553 input_tokens: 40_000,
2554 output_tokens: 20_000,
2555 cache_creation_input_tokens: 0,
2556 cache_read_input_tokens: 0,
2557 },
2558 ));
2559 cx.run_until_parked();
2560 thread.read_with(cx, |thread, _| {
2561 assert_eq!(
2562 thread.to_markdown(),
2563 indoc! {"
2564 ## User
2565
2566 Hi
2567
2568 ## Assistant
2569
2570 Ahoy!
2571 "}
2572 );
2573
2574 assert_eq!(
2575 thread.latest_token_usage(),
2576 Some(acp_thread::TokenUsage {
2577 used_tokens: 40_000 + 20_000,
2578 max_tokens: 1_000_000,
2579 output_tokens: 20_000,
2580 })
2581 );
2582 });
2583}
2584
2585#[gpui::test]
2586async fn test_truncate_second_message(cx: &mut TestAppContext) {
2587 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2588 let fake_model = model.as_fake();
2589
2590 thread
2591 .update(cx, |thread, cx| {
2592 thread.send(UserMessageId::new(), ["Message 1"], cx)
2593 })
2594 .unwrap();
2595 cx.run_until_parked();
2596 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2597 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2598 language_model::TokenUsage {
2599 input_tokens: 32_000,
2600 output_tokens: 16_000,
2601 cache_creation_input_tokens: 0,
2602 cache_read_input_tokens: 0,
2603 },
2604 ));
2605 fake_model.end_last_completion_stream();
2606 cx.run_until_parked();
2607
2608 let assert_first_message_state = |cx: &mut TestAppContext| {
2609 thread.clone().read_with(cx, |thread, _| {
2610 assert_eq!(
2611 thread.to_markdown(),
2612 indoc! {"
2613 ## User
2614
2615 Message 1
2616
2617 ## Assistant
2618
2619 Message 1 response
2620 "}
2621 );
2622
2623 assert_eq!(
2624 thread.latest_token_usage(),
2625 Some(acp_thread::TokenUsage {
2626 used_tokens: 32_000 + 16_000,
2627 max_tokens: 1_000_000,
2628 output_tokens: 16_000,
2629 })
2630 );
2631 });
2632 };
2633
2634 assert_first_message_state(cx);
2635
2636 let second_message_id = UserMessageId::new();
2637 thread
2638 .update(cx, |thread, cx| {
2639 thread.send(second_message_id.clone(), ["Message 2"], cx)
2640 })
2641 .unwrap();
2642 cx.run_until_parked();
2643
2644 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2645 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2646 language_model::TokenUsage {
2647 input_tokens: 40_000,
2648 output_tokens: 20_000,
2649 cache_creation_input_tokens: 0,
2650 cache_read_input_tokens: 0,
2651 },
2652 ));
2653 fake_model.end_last_completion_stream();
2654 cx.run_until_parked();
2655
2656 thread.read_with(cx, |thread, _| {
2657 assert_eq!(
2658 thread.to_markdown(),
2659 indoc! {"
2660 ## User
2661
2662 Message 1
2663
2664 ## Assistant
2665
2666 Message 1 response
2667
2668 ## User
2669
2670 Message 2
2671
2672 ## Assistant
2673
2674 Message 2 response
2675 "}
2676 );
2677
2678 assert_eq!(
2679 thread.latest_token_usage(),
2680 Some(acp_thread::TokenUsage {
2681 used_tokens: 40_000 + 20_000,
2682 max_tokens: 1_000_000,
2683 output_tokens: 20_000,
2684 })
2685 );
2686 });
2687
2688 thread
2689 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2690 .unwrap();
2691 cx.run_until_parked();
2692
2693 assert_first_message_state(cx);
2694}
2695
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // Verifies that the first completed exchange triggers title generation through
    // the summarization model, and that a later exchange does not request a new title.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // A separate fake model for summarization, driven independently of `fake_model`.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The title keeps its default value until the summary completion finishes.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The summary text is split across chunk boundaries and contains a newline;
    // only "Hello world" (the text before the newline) is expected as the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No request reached the summarization model for the second exchange.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2741
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Verifies that `build_completion_request` leaves out tool uses that have no
    // result yet, while keeping completed tool uses together with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // `always_allow_tools` is not called, so the permission-gated tool is expected
    // to stay pending (NOTE(review): presumably awaiting a permission decision),
    // while EchoTool completes and produces a result.
    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    // The comparison starts at index 1; NOTE(review): message 0 is presumably the
    // system prompt, which this test does not pin down — confirm.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2821
2822#[gpui::test]
2823async fn test_agent_connection(cx: &mut TestAppContext) {
2824 cx.update(settings::init);
2825 let templates = Templates::new();
2826
2827 // Initialize language model system with test provider
2828 cx.update(|cx| {
2829 gpui_tokio::init(cx);
2830
2831 let http_client = FakeHttpClient::with_404_response();
2832 let clock = Arc::new(clock::FakeSystemClock::new());
2833 let client = Client::new(clock, http_client, cx);
2834 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2835 language_model::init(client.clone(), cx);
2836 language_models::init(user_store, client.clone(), cx);
2837 LanguageModelRegistry::test(cx);
2838 });
2839 cx.executor().forbid_parking();
2840
2841 // Create a project for new_thread
2842 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2843 fake_fs.insert_tree(path!("/test"), json!({})).await;
2844 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2845 let cwd = Path::new("/test");
2846 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2847
2848 // Create agent and connection
2849 let agent = NativeAgent::new(
2850 project.clone(),
2851 thread_store,
2852 templates.clone(),
2853 None,
2854 fake_fs.clone(),
2855 &mut cx.to_async(),
2856 )
2857 .await
2858 .unwrap();
2859 let connection = NativeAgentConnection(agent.clone());
2860
2861 // Create a thread using new_thread
2862 let connection_rc = Rc::new(connection.clone());
2863 let acp_thread = cx
2864 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2865 .await
2866 .expect("new_thread should succeed");
2867
2868 // Get the session_id from the AcpThread
2869 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2870
2871 // Test model_selector returns Some
2872 let selector_opt = connection.model_selector(&session_id);
2873 assert!(
2874 selector_opt.is_some(),
2875 "agent should always support ModelSelector"
2876 );
2877 let selector = selector_opt.unwrap();
2878
2879 // Test list_models
2880 let listed_models = cx
2881 .update(|cx| selector.list_models(cx))
2882 .await
2883 .expect("list_models should succeed");
2884 let AgentModelList::Grouped(listed_models) = listed_models else {
2885 panic!("Unexpected model list type");
2886 };
2887 assert!(!listed_models.is_empty(), "should have at least one model");
2888 assert_eq!(
2889 listed_models[&AgentModelGroupName("Fake".into())][0]
2890 .id
2891 .0
2892 .as_ref(),
2893 "fake/fake"
2894 );
2895
2896 // Test selected_model returns the default
2897 let model = cx
2898 .update(|cx| selector.selected_model(cx))
2899 .await
2900 .expect("selected_model should succeed");
2901 let model = cx
2902 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2903 .unwrap();
2904 let model = model.as_fake();
2905 assert_eq!(model.id().0, "fake", "should return default model");
2906
2907 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2908 cx.run_until_parked();
2909 model.send_last_completion_stream_text_chunk("def");
2910 cx.run_until_parked();
2911 acp_thread.read_with(cx, |thread, cx| {
2912 assert_eq!(
2913 thread.to_markdown(cx),
2914 indoc! {"
2915 ## User
2916
2917 abc
2918
2919 ## Assistant
2920
2921 def
2922
2923 "}
2924 )
2925 });
2926
2927 // Test cancel
2928 cx.update(|cx| connection.cancel(&session_id, cx));
2929 request.await.expect("prompt should fail gracefully");
2930
2931 // Ensure that dropping the ACP thread causes the native thread to be
2932 // dropped as well.
2933 cx.update(|_| drop(acp_thread));
2934 let result = cx
2935 .update(|cx| {
2936 connection.prompt(
2937 Some(acp_thread::UserMessageId::new()),
2938 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2939 cx,
2940 )
2941 })
2942 .await;
2943 assert_eq!(
2944 result.as_ref().unwrap_err().to_string(),
2945 "Session not found",
2946 "unexpected result: {:?}",
2947 result
2948 );
2949}
2950
2951#[gpui::test]
2952async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2953 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2954 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2955 let fake_model = model.as_fake();
2956
2957 let mut events = thread
2958 .update(cx, |thread, cx| {
2959 thread.send(UserMessageId::new(), ["Think"], cx)
2960 })
2961 .unwrap();
2962 cx.run_until_parked();
2963
2964 // Simulate streaming partial input.
2965 let input = json!({});
2966 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2967 LanguageModelToolUse {
2968 id: "1".into(),
2969 name: ThinkingTool::name().into(),
2970 raw_input: input.to_string(),
2971 input,
2972 is_input_complete: false,
2973 thought_signature: None,
2974 },
2975 ));
2976
2977 // Input streaming completed
2978 let input = json!({ "content": "Thinking hard!" });
2979 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2980 LanguageModelToolUse {
2981 id: "1".into(),
2982 name: "thinking".into(),
2983 raw_input: input.to_string(),
2984 input,
2985 is_input_complete: true,
2986 thought_signature: None,
2987 },
2988 ));
2989 fake_model.end_last_completion_stream();
2990 cx.run_until_parked();
2991
2992 let tool_call = expect_tool_call(&mut events).await;
2993 assert_eq!(
2994 tool_call,
2995 acp::ToolCall::new("1", "Thinking")
2996 .kind(acp::ToolKind::Think)
2997 .raw_input(json!({}))
2998 .meta(acp::Meta::from_iter([(
2999 "tool_name".into(),
3000 "thinking".into()
3001 )]))
3002 );
3003 let update = expect_tool_call_update_fields(&mut events).await;
3004 assert_eq!(
3005 update,
3006 acp::ToolCallUpdate::new(
3007 "1",
3008 acp::ToolCallUpdateFields::new()
3009 .title("Thinking")
3010 .kind(acp::ToolKind::Think)
3011 .raw_input(json!({ "content": "Thinking hard!"}))
3012 )
3013 );
3014 let update = expect_tool_call_update_fields(&mut events).await;
3015 assert_eq!(
3016 update,
3017 acp::ToolCallUpdate::new(
3018 "1",
3019 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3020 )
3021 );
3022 let update = expect_tool_call_update_fields(&mut events).await;
3023 assert_eq!(
3024 update,
3025 acp::ToolCallUpdate::new(
3026 "1",
3027 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3028 )
3029 );
3030 let update = expect_tool_call_update_fields(&mut events).await;
3031 assert_eq!(
3032 update,
3033 acp::ToolCallUpdate::new(
3034 "1",
3035 acp::ToolCallUpdateFields::new()
3036 .status(acp::ToolCallStatus::Completed)
3037 .raw_output("Finished thinking.")
3038 )
3039 );
3040}
3041
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // Verifies that a completion that succeeds on the first attempt emits no
    // `ThreadEvent::Retry` events and produces a single assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // NOTE(review): Burn mode is set as in the retry tests; presumably retries are
    // only possible in this mode, making the "no retry" assertion meaningful — confirm.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Collect retry events until the turn stops. The loop also ends on the first
    // `Err` item or when the stream closes.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
                "}
        )
    });
}
3085
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // Verifies that a retryable error mid-stream triggers exactly one retry after
    // the provider-specified delay, and that the partial and retried outputs are
    // both kept in the thread, separated by a "[resume]" marker.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream partial output, then fail with an error that carries a retry delay.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the full `retry_after` so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Collect retry events until the turn stops (or the stream errors/closes).
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
                "}
        )
    });
}
3150
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // Verifies that when a completion fails after emitting a tool use, the tool
    // still runs and the retried request includes both the tool use and its result.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit a complete tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the newest pending completion is the retried request.
    // It must contain the tool use followed by the EchoTool result. The comparison
    // starts at index 1; NOTE(review): message 0 is presumably the system prompt.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried request and drain all events; the last message is then
    // the retried agent reply with no tool results of its own.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3231
3232#[gpui::test]
3233async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3234 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3235 let fake_model = model.as_fake();
3236
3237 let mut events = thread
3238 .update(cx, |thread, cx| {
3239 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3240 thread.send(UserMessageId::new(), ["Hello!"], cx)
3241 })
3242 .unwrap();
3243 cx.run_until_parked();
3244
3245 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3246 fake_model.send_last_completion_stream_error(
3247 LanguageModelCompletionError::ServerOverloaded {
3248 provider: LanguageModelProviderName::new("Anthropic"),
3249 retry_after: Some(Duration::from_secs(3)),
3250 },
3251 );
3252 fake_model.end_last_completion_stream();
3253 cx.executor().advance_clock(Duration::from_secs(3));
3254 cx.run_until_parked();
3255 }
3256
3257 let mut errors = Vec::new();
3258 let mut retry_events = Vec::new();
3259 while let Some(event) = events.next().await {
3260 match event {
3261 Ok(ThreadEvent::Retry(retry_status)) => {
3262 retry_events.push(retry_status);
3263 }
3264 Ok(ThreadEvent::Stop(..)) => break,
3265 Err(error) => errors.push(error),
3266 _ => {}
3267 }
3268 }
3269
3270 assert_eq!(
3271 retry_events.len(),
3272 crate::thread::MAX_RETRY_ATTEMPTS as usize
3273 );
3274 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3275 assert_eq!(retry_events[i].attempt, i + 1);
3276 }
3277 assert_eq!(errors.len(), 1);
3278 let error = errors[0]
3279 .downcast_ref::<LanguageModelCompletionError>()
3280 .unwrap();
3281 assert!(matches!(
3282 error,
3283 LanguageModelCompletionError::ServerOverloaded { .. }
3284 ));
3285}
3286
3287/// Filters out the stop events for asserting against in tests
3288fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3289 result_events
3290 .into_iter()
3291 .filter_map(|event| match event.unwrap() {
3292 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3293 _ => None,
3294 })
3295 .collect()
3296}
3297
/// Handles produced by [`setup`] for driving a [`Thread`] in tests.
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread.
    project_context: Entity<ProjectContext>,
    /// The test project's context server store, e.g. for `setup_context_server`.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing the project and the user settings file.
    fs: Arc<FakeFs>,
}
3305
3306enum TestModel {
3307 Sonnet4,
3308 Fake,
3309}
3310
3311impl TestModel {
3312 fn id(&self) -> LanguageModelId {
3313 match self {
3314 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3315 TestModel::Fake => unreachable!(),
3316 }
3317 }
3318}
3319
/// Builds a [`Thread`] backed by `model` inside a fresh test project.
///
/// Steps:
/// - Write a user settings file to a fake filesystem. Its default agent profile
///   ("test-profile") enables the test tools and `terminal`.
/// - Watch that file so its contents are applied to the [`SettingsStore`].
/// - Create a project rooted at `/test`.
/// - Resolve the model. For [`TestModel::Sonnet4`] this also initializes the production
///   client and awaits provider authentication.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed so the executor may block on non-simulated work
    // (e.g. real network I/O for Sonnet4) — confirm.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real model needs an HTTP client, user store, and provider registry.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake one is constructed directly; the real one is looked up
    // in the registry and its provider is authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3423
/// Initializes `env_logger` before any test runs, but only when `RUST_LOG` is set (and
/// readable as Unicode), so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let Ok(_) = std::env::var("RUST_LOG") else {
        return;
    };
    env_logger::init();
}
3431
3432fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3433 let fs = fs.clone();
3434 cx.spawn({
3435 async move |cx| {
3436 let mut new_settings_content_rx = settings::watch_config_file(
3437 cx.background_executor(),
3438 fs,
3439 paths::settings_file().clone(),
3440 );
3441
3442 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3443 cx.update(|cx| {
3444 SettingsStore::update_global(cx, |settings, cx| {
3445 settings.set_user_settings(&new_settings_content, cx)
3446 })
3447 })
3448 .ok();
3449 }
3450 }
3451 })
3452 .detach();
3453}
3454
3455fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3456 completion
3457 .tools
3458 .iter()
3459 .map(|tool| tool.name.clone())
3460 .collect()
3461}
3462
/// Registers and starts a fake stdio MCP context server named `name` that advertises `tools`.
///
/// The global project settings are overridden to include an enabled stdio entry for
/// `name`. `Initialize` and `ListTools` requests are answered in-process.
///
/// Each `CallTool` request is forwarded through the returned receiver, paired with a
/// oneshot sender. The test must reply on that sender for the call to complete.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Hand each tool call to the test and wait for it to supply the response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3542
/// Checks `tokens_before_message` across three turns. The first message has no count. Each
/// later message reports the input-token count from the usage update of the previous
/// request, and earlier values stay unchanged as the thread grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3649
/// Checks that after truncating at a message, `tokens_before_message` returns `None` for
/// that now-removed message. The surviving first message still has no count.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages with responses and usage
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3720
/// Exercises per-tool permission rules for the `terminal` tool:
/// 1. An `always_deny` pattern blocks a matching command.
/// 2. An `always_allow` pattern runs a command without confirmation, even with
///    `default_mode: Deny`.
/// 3. An `always_confirm` pattern forces an authorization request, even when
///    `always_allow_tool_actions` is true.
/// 4. `default_mode: Deny` rejects a command that matches no pattern.
///
/// The scenarios run sequentially against the same project, and each overrides the global
/// `AgentSettings` before running the tool.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // NOTE(review): the environment is an `Rc`, so the tool is presumably not
        // Send/Sync; the lint is silenced because this test is single-threaded.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        assert!(
            result.unwrap_err().to_string().contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // Terminal content arriving without an authorization request shows the allow rule
        // bypassed confirmation.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        // The task is kept alive (not awaited) only so the authorization request can be
        // observed; it is dropped at the end of this scope.
        let _task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let auth = rx.expect_authorization().await;
        assert!(
            auth.tool_call.fields.title.is_some(),
            "expected authorization request for sudo command despite always_allow_tool_actions=true"
        );
    }

    // Test 4: default_mode: Deny blocks commands when no pattern matches
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by default_mode: Deny"
        );
        assert!(
            result.unwrap_err().to_string().contains("disabled"),
            "error should mention the tool is disabled"
        );
    }
}
3936
/// With the `subagents` feature flag on, `add_default_tools` registers the `subagent` tool
/// on a top-level thread.
#[gpui::test]
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // `add_default_tools` takes a thread environment; a fake terminal handle is enough here.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let thread = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.has_registered_tool("subagent"),
            "subagent tool should be present when feature flag is enabled"
        );
    });
}
3977
3978#[gpui::test]
3979async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3980 init_test(cx);
3981
3982 cx.update(|cx| {
3983 cx.update_flags(true, vec!["subagents".to_string()]);
3984 });
3985
3986 let fs = FakeFs::new(cx.executor());
3987 fs.insert_tree(path!("/test"), json!({})).await;
3988 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3989 let project_context = cx.new(|_cx| ProjectContext::default());
3990 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3991 let context_server_registry =
3992 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3993 let model = Arc::new(FakeLanguageModel::default());
3994
3995 let subagent_context = SubagentContext {
3996 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3997 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3998 depth: 1,
3999 summary_prompt: "Summarize".to_string(),
4000 context_low_prompt: "Context low".to_string(),
4001 };
4002
4003 let subagent = cx.new(|cx| {
4004 Thread::new_subagent(
4005 project.clone(),
4006 project_context,
4007 context_server_registry,
4008 Templates::new(),
4009 model.clone(),
4010 subagent_context,
4011 std::collections::BTreeMap::new(),
4012 cx,
4013 )
4014 });
4015
4016 subagent.read_with(cx, |thread, _| {
4017 assert!(thread.is_subagent());
4018 assert_eq!(thread.depth(), 1);
4019 assert!(thread.model().is_some());
4020 });
4021}
4022
/// A subagent already at `MAX_SUBAGENT_DEPTH` must not get the `subagent` tool from
/// `add_default_tools`, which prevents further nesting.
#[gpui::test]
async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // Start the subagent at the maximum allowed depth.
    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: MAX_SUBAGENT_DEPTH,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let deep_subagent = cx.new(|cx| {
        let mut thread = Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    deep_subagent.read_with(cx, |thread, _| {
        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
        assert!(
            !thread.has_registered_tool("subagent"),
            "subagent tool should not be present at max depth"
        );
    });
}
4074
4075#[gpui::test]
4076async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
4077 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4078 let fake_model = model.as_fake();
4079
4080 cx.update(|cx| {
4081 cx.update_flags(true, vec!["subagents".to_string()]);
4082 });
4083
4084 let subagent_context = SubagentContext {
4085 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4086 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4087 depth: 1,
4088 summary_prompt: "Summarize your work".to_string(),
4089 context_low_prompt: "Context low, wrap up".to_string(),
4090 };
4091
4092 let project = thread.read_with(cx, |t, _| t.project.clone());
4093 let project_context = cx.new(|_cx| ProjectContext::default());
4094 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4095 let context_server_registry =
4096 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4097
4098 let subagent = cx.new(|cx| {
4099 Thread::new_subagent(
4100 project.clone(),
4101 project_context,
4102 context_server_registry,
4103 Templates::new(),
4104 model.clone(),
4105 subagent_context,
4106 std::collections::BTreeMap::new(),
4107 cx,
4108 )
4109 });
4110
4111 let task_prompt = "Find all TODO comments in the codebase";
4112 subagent
4113 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4114 .unwrap();
4115 cx.run_until_parked();
4116
4117 let pending = fake_model.pending_completions();
4118 assert_eq!(pending.len(), 1, "should have one pending completion");
4119
4120 let messages = &pending[0].messages;
4121 let user_messages: Vec<_> = messages
4122 .iter()
4123 .filter(|m| m.role == language_model::Role::User)
4124 .collect();
4125 assert_eq!(user_messages.len(), 1, "should have one user message");
4126
4127 let content = &user_messages[0].content[0];
4128 assert!(
4129 content.to_str().unwrap().contains("TODO"),
4130 "task prompt should be in user message"
4131 );
4132}
4133
/// After a subagent finishes a turn, `request_final_summary` issues another completion.
/// The last user message of that completion contains the summary prompt.
#[gpui::test]
async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Please summarize what you found".to_string(),
        context_low_prompt: "Context low, wrap up".to_string(),
    };

    let project = thread.read_with(cx, |t, _| t.project.clone());
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    subagent
        .update(cx, |thread, cx| {
            thread.submit_user_message("Do some work", cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Finish the work turn so the subagent is idle before asking for the summary.
    fake_model.send_last_completion_stream_text_chunk("I did the work");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    subagent
        .update(cx, |thread, cx| thread.request_final_summary(cx))
        .unwrap();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "should have pending completion for summary"
    );

    // The summary request is the most recent pending completion.
    let messages = &pending.last().unwrap().messages;
    let user_messages: Vec<_> = messages
        .iter()
        .filter(|m| m.role == language_model::Role::User)
        .collect();

    let last_user = user_messages.last().unwrap();
    assert!(
        last_user.content[0].to_str().unwrap().contains("summarize"),
        "summary prompt should be sent"
    );
}
4206
4207#[gpui::test]
4208async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4209 init_test(cx);
4210
4211 cx.update(|cx| {
4212 cx.update_flags(true, vec!["subagents".to_string()]);
4213 });
4214
4215 let fs = FakeFs::new(cx.executor());
4216 fs.insert_tree(path!("/test"), json!({})).await;
4217 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4218 let project_context = cx.new(|_cx| ProjectContext::default());
4219 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4220 let context_server_registry =
4221 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4222 let model = Arc::new(FakeLanguageModel::default());
4223
4224 let subagent_context = SubagentContext {
4225 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4226 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4227 depth: 1,
4228 summary_prompt: "Summarize".to_string(),
4229 context_low_prompt: "Context low".to_string(),
4230 };
4231
4232 let subagent = cx.new(|cx| {
4233 let mut thread = Thread::new_subagent(
4234 project.clone(),
4235 project_context,
4236 context_server_registry,
4237 Templates::new(),
4238 model.clone(),
4239 subagent_context,
4240 std::collections::BTreeMap::new(),
4241 cx,
4242 );
4243 thread.add_tool(EchoTool);
4244 thread.add_tool(DelayTool);
4245 thread.add_tool(WordListTool);
4246 thread
4247 });
4248
4249 subagent.read_with(cx, |thread, _| {
4250 assert!(thread.has_registered_tool("echo"));
4251 assert!(thread.has_registered_tool("delay"));
4252 assert!(thread.has_registered_tool("word_list"));
4253 });
4254
4255 let allowed: collections::HashSet<gpui::SharedString> =
4256 vec!["echo".into()].into_iter().collect();
4257
4258 subagent.update(cx, |thread, _cx| {
4259 thread.restrict_tools(&allowed);
4260 });
4261
4262 subagent.read_with(cx, |thread, _| {
4263 assert!(
4264 thread.has_registered_tool("echo"),
4265 "echo should still be available"
4266 );
4267 assert!(
4268 !thread.has_registered_tool("delay"),
4269 "delay should be removed"
4270 );
4271 assert!(
4272 !thread.has_registered_tool("word_list"),
4273 "word_list should be removed"
4274 );
4275 });
4276}
4277
4278#[gpui::test]
4279async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4280 init_test(cx);
4281
4282 cx.update(|cx| {
4283 cx.update_flags(true, vec!["subagents".to_string()]);
4284 });
4285
4286 let fs = FakeFs::new(cx.executor());
4287 fs.insert_tree(path!("/test"), json!({})).await;
4288 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4289 let project_context = cx.new(|_cx| ProjectContext::default());
4290 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4291 let context_server_registry =
4292 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4293 let model = Arc::new(FakeLanguageModel::default());
4294
4295 let parent = cx.new(|cx| {
4296 Thread::new(
4297 project.clone(),
4298 project_context.clone(),
4299 context_server_registry.clone(),
4300 Templates::new(),
4301 Some(model.clone()),
4302 cx,
4303 )
4304 });
4305
4306 let subagent_context = SubagentContext {
4307 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4308 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4309 depth: 1,
4310 summary_prompt: "Summarize".to_string(),
4311 context_low_prompt: "Context low".to_string(),
4312 };
4313
4314 let subagent = cx.new(|cx| {
4315 Thread::new_subagent(
4316 project.clone(),
4317 project_context.clone(),
4318 context_server_registry.clone(),
4319 Templates::new(),
4320 model.clone(),
4321 subagent_context,
4322 std::collections::BTreeMap::new(),
4323 cx,
4324 )
4325 });
4326
4327 parent.update(cx, |thread, _cx| {
4328 thread.register_running_subagent(subagent.downgrade());
4329 });
4330
4331 subagent
4332 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4333 .unwrap();
4334 cx.run_until_parked();
4335
4336 subagent.read_with(cx, |thread, _| {
4337 assert!(!thread.is_turn_complete(), "subagent should be running");
4338 });
4339
4340 parent.update(cx, |thread, cx| {
4341 thread.cancel(cx).detach();
4342 });
4343
4344 subagent.read_with(cx, |thread, _| {
4345 assert!(
4346 thread.is_turn_complete(),
4347 "subagent should be cancelled when parent cancels"
4348 );
4349 });
4350}
4351
4352#[gpui::test]
4353async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4354 // This test verifies that the subagent tool properly handles user cancellation
4355 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4356 init_test(cx);
4357 always_allow_tools(cx);
4358
4359 cx.update(|cx| {
4360 cx.update_flags(true, vec!["subagents".to_string()]);
4361 });
4362
4363 let fs = FakeFs::new(cx.executor());
4364 fs.insert_tree(path!("/test"), json!({})).await;
4365 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4366 let project_context = cx.new(|_cx| ProjectContext::default());
4367 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4368 let context_server_registry =
4369 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4370 let model = Arc::new(FakeLanguageModel::default());
4371
4372 let parent = cx.new(|cx| {
4373 Thread::new(
4374 project.clone(),
4375 project_context.clone(),
4376 context_server_registry.clone(),
4377 Templates::new(),
4378 Some(model.clone()),
4379 cx,
4380 )
4381 });
4382
4383 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4384 std::collections::BTreeMap::new();
4385
4386 #[allow(clippy::arc_with_non_send_sync)]
4387 let tool = Arc::new(SubagentTool::new(
4388 parent.downgrade(),
4389 project.clone(),
4390 project_context,
4391 context_server_registry,
4392 Templates::new(),
4393 0,
4394 parent_tools,
4395 ));
4396
4397 let (event_stream, _rx, mut cancellation_tx) =
4398 crate::ToolCallEventStream::test_with_cancellation();
4399
4400 // Start the subagent tool
4401 let task = cx.update(|cx| {
4402 tool.run(
4403 SubagentToolInput {
4404 subagents: vec![crate::SubagentConfig {
4405 label: "Long running task".to_string(),
4406 task_prompt: "Do a very long task that takes forever".to_string(),
4407 summary_prompt: "Summarize".to_string(),
4408 context_low_prompt: "Context low".to_string(),
4409 timeout_ms: None,
4410 allowed_tools: None,
4411 }],
4412 },
4413 event_stream.clone(),
4414 cx,
4415 )
4416 });
4417
4418 cx.run_until_parked();
4419
4420 // Signal cancellation via the event stream
4421 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4422
4423 // The task should complete promptly with a cancellation error
4424 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4425 let result = futures::select! {
4426 result = task.fuse() => result,
4427 _ = timeout.fuse() => {
4428 panic!("subagent tool did not respond to cancellation within timeout");
4429 }
4430 };
4431
4432 // Verify we got a cancellation error
4433 let err = result.unwrap_err();
4434 assert!(
4435 err.to_string().contains("cancelled by user"),
4436 "expected cancellation error, got: {}",
4437 err
4438 );
4439}
4440
4441#[gpui::test]
4442async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4443 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4444 let fake_model = model.as_fake();
4445
4446 cx.update(|cx| {
4447 cx.update_flags(true, vec!["subagents".to_string()]);
4448 });
4449
4450 let subagent_context = SubagentContext {
4451 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4452 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4453 depth: 1,
4454 summary_prompt: "Summarize".to_string(),
4455 context_low_prompt: "Context low".to_string(),
4456 };
4457
4458 let project = thread.read_with(cx, |t, _| t.project.clone());
4459 let project_context = cx.new(|_cx| ProjectContext::default());
4460 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4461 let context_server_registry =
4462 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4463
4464 let subagent = cx.new(|cx| {
4465 Thread::new_subagent(
4466 project.clone(),
4467 project_context,
4468 context_server_registry,
4469 Templates::new(),
4470 model.clone(),
4471 subagent_context,
4472 std::collections::BTreeMap::new(),
4473 cx,
4474 )
4475 });
4476
4477 subagent
4478 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4479 .unwrap();
4480 cx.run_until_parked();
4481
4482 subagent.read_with(cx, |thread, _| {
4483 assert!(!thread.is_turn_complete(), "turn should be in progress");
4484 });
4485
4486 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4487 provider: LanguageModelProviderName::from("Fake".to_string()),
4488 });
4489 fake_model.end_last_completion_stream();
4490 cx.run_until_parked();
4491
4492 subagent.read_with(cx, |thread, _| {
4493 assert!(
4494 thread.is_turn_complete(),
4495 "turn should be complete after non-retryable error"
4496 );
4497 });
4498}
4499
4500#[gpui::test]
4501async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4502 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4503 let fake_model = model.as_fake();
4504
4505 cx.update(|cx| {
4506 cx.update_flags(true, vec!["subagents".to_string()]);
4507 });
4508
4509 let subagent_context = SubagentContext {
4510 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4511 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4512 depth: 1,
4513 summary_prompt: "Summarize your work".to_string(),
4514 context_low_prompt: "Context low, stop and summarize".to_string(),
4515 };
4516
4517 let project = thread.read_with(cx, |t, _| t.project.clone());
4518 let project_context = cx.new(|_cx| ProjectContext::default());
4519 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4520 let context_server_registry =
4521 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4522
4523 let subagent = cx.new(|cx| {
4524 Thread::new_subagent(
4525 project.clone(),
4526 project_context.clone(),
4527 context_server_registry.clone(),
4528 Templates::new(),
4529 model.clone(),
4530 subagent_context.clone(),
4531 std::collections::BTreeMap::new(),
4532 cx,
4533 )
4534 });
4535
4536 subagent.update(cx, |thread, _| {
4537 thread.add_tool(EchoTool);
4538 });
4539
4540 subagent
4541 .update(cx, |thread, cx| {
4542 thread.submit_user_message("Do some work", cx)
4543 })
4544 .unwrap();
4545 cx.run_until_parked();
4546
4547 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4548 fake_model
4549 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4550 fake_model.end_last_completion_stream();
4551 cx.run_until_parked();
4552
4553 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4554 assert!(
4555 interrupt_result.is_ok(),
4556 "interrupt_for_summary should succeed"
4557 );
4558
4559 cx.run_until_parked();
4560
4561 let pending = fake_model.pending_completions();
4562 assert!(
4563 !pending.is_empty(),
4564 "should have pending completion for interrupted summary"
4565 );
4566
4567 let messages = &pending.last().unwrap().messages;
4568 let user_messages: Vec<_> = messages
4569 .iter()
4570 .filter(|m| m.role == language_model::Role::User)
4571 .collect();
4572
4573 let last_user = user_messages.last().unwrap();
4574 let content_str = last_user.content[0].to_str().unwrap();
4575 assert!(
4576 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4577 "context_low_prompt should be sent when interrupting: got {:?}",
4578 content_str
4579 );
4580}
4581
4582#[gpui::test]
4583async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4584 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4585 let fake_model = model.as_fake();
4586
4587 cx.update(|cx| {
4588 cx.update_flags(true, vec!["subagents".to_string()]);
4589 });
4590
4591 let subagent_context = SubagentContext {
4592 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4593 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4594 depth: 1,
4595 summary_prompt: "Summarize".to_string(),
4596 context_low_prompt: "Context low".to_string(),
4597 };
4598
4599 let project = thread.read_with(cx, |t, _| t.project.clone());
4600 let project_context = cx.new(|_cx| ProjectContext::default());
4601 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4602 let context_server_registry =
4603 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4604
4605 let subagent = cx.new(|cx| {
4606 Thread::new_subagent(
4607 project.clone(),
4608 project_context,
4609 context_server_registry,
4610 Templates::new(),
4611 model.clone(),
4612 subagent_context,
4613 std::collections::BTreeMap::new(),
4614 cx,
4615 )
4616 });
4617
4618 subagent
4619 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4620 .unwrap();
4621 cx.run_until_parked();
4622
4623 let max_tokens = model.max_token_count();
4624 let high_usage = language_model::TokenUsage {
4625 input_tokens: (max_tokens as f64 * 0.80) as u64,
4626 output_tokens: 0,
4627 cache_creation_input_tokens: 0,
4628 cache_read_input_tokens: 0,
4629 };
4630
4631 fake_model
4632 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4633 fake_model.send_last_completion_stream_text_chunk("Working...");
4634 fake_model
4635 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4636 fake_model.end_last_completion_stream();
4637 cx.run_until_parked();
4638
4639 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4640 assert!(usage.is_some(), "should have token usage after completion");
4641
4642 let usage = usage.unwrap();
4643 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4644 assert!(
4645 remaining_ratio <= 0.25,
4646 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4647 remaining_ratio * 100.0
4648 );
4649}
4650
4651#[gpui::test]
4652async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4653 init_test(cx);
4654
4655 cx.update(|cx| {
4656 cx.update_flags(true, vec!["subagents".to_string()]);
4657 });
4658
4659 let fs = FakeFs::new(cx.executor());
4660 fs.insert_tree(path!("/test"), json!({})).await;
4661 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4662 let project_context = cx.new(|_cx| ProjectContext::default());
4663 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4664 let context_server_registry =
4665 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4666 let model = Arc::new(FakeLanguageModel::default());
4667
4668 let parent = cx.new(|cx| {
4669 let mut thread = Thread::new(
4670 project.clone(),
4671 project_context.clone(),
4672 context_server_registry.clone(),
4673 Templates::new(),
4674 Some(model.clone()),
4675 cx,
4676 );
4677 thread.add_tool(EchoTool);
4678 thread
4679 });
4680
4681 let mut parent_tools: std::collections::BTreeMap<
4682 gpui::SharedString,
4683 Arc<dyn crate::AnyAgentTool>,
4684 > = std::collections::BTreeMap::new();
4685 parent_tools.insert("echo".into(), EchoTool.erase());
4686
4687 #[allow(clippy::arc_with_non_send_sync)]
4688 let tool = Arc::new(SubagentTool::new(
4689 parent.downgrade(),
4690 project,
4691 project_context,
4692 context_server_registry,
4693 Templates::new(),
4694 0,
4695 parent_tools,
4696 ));
4697
4698 let subagent_configs = vec![crate::SubagentConfig {
4699 label: "Test".to_string(),
4700 task_prompt: "Do something".to_string(),
4701 summary_prompt: "Summarize".to_string(),
4702 context_low_prompt: "Context low".to_string(),
4703 timeout_ms: None,
4704 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4705 }];
4706 let result = tool.validate_subagents(&subagent_configs);
4707 assert!(result.is_err(), "should reject unknown tool");
4708 let err_msg = result.unwrap_err().to_string();
4709 assert!(
4710 err_msg.contains("nonexistent_tool"),
4711 "error should mention the invalid tool name: {}",
4712 err_msg
4713 );
4714 assert!(
4715 err_msg.contains("do not exist"),
4716 "error should explain the tool does not exist: {}",
4717 err_msg
4718 );
4719}
4720
4721#[gpui::test]
4722async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4723 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4724 let fake_model = model.as_fake();
4725
4726 cx.update(|cx| {
4727 cx.update_flags(true, vec!["subagents".to_string()]);
4728 });
4729
4730 let subagent_context = SubagentContext {
4731 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4732 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4733 depth: 1,
4734 summary_prompt: "Summarize".to_string(),
4735 context_low_prompt: "Context low".to_string(),
4736 };
4737
4738 let project = thread.read_with(cx, |t, _| t.project.clone());
4739 let project_context = cx.new(|_cx| ProjectContext::default());
4740 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4741 let context_server_registry =
4742 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4743
4744 let subagent = cx.new(|cx| {
4745 Thread::new_subagent(
4746 project.clone(),
4747 project_context,
4748 context_server_registry,
4749 Templates::new(),
4750 model.clone(),
4751 subagent_context,
4752 std::collections::BTreeMap::new(),
4753 cx,
4754 )
4755 });
4756
4757 subagent
4758 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4759 .unwrap();
4760 cx.run_until_parked();
4761
4762 fake_model
4763 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4764 fake_model.end_last_completion_stream();
4765 cx.run_until_parked();
4766
4767 subagent.read_with(cx, |thread, _| {
4768 assert!(
4769 thread.is_turn_complete(),
4770 "turn should complete even with empty response"
4771 );
4772 });
4773}
4774
4775#[gpui::test]
4776async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4777 init_test(cx);
4778
4779 cx.update(|cx| {
4780 cx.update_flags(true, vec!["subagents".to_string()]);
4781 });
4782
4783 let fs = FakeFs::new(cx.executor());
4784 fs.insert_tree(path!("/test"), json!({})).await;
4785 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4786 let project_context = cx.new(|_cx| ProjectContext::default());
4787 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4788 let context_server_registry =
4789 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4790 let model = Arc::new(FakeLanguageModel::default());
4791
4792 let depth_1_context = SubagentContext {
4793 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4794 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4795 depth: 1,
4796 summary_prompt: "Summarize".to_string(),
4797 context_low_prompt: "Context low".to_string(),
4798 };
4799
4800 let depth_1_subagent = cx.new(|cx| {
4801 Thread::new_subagent(
4802 project.clone(),
4803 project_context.clone(),
4804 context_server_registry.clone(),
4805 Templates::new(),
4806 model.clone(),
4807 depth_1_context,
4808 std::collections::BTreeMap::new(),
4809 cx,
4810 )
4811 });
4812
4813 depth_1_subagent.read_with(cx, |thread, _| {
4814 assert_eq!(thread.depth(), 1);
4815 assert!(thread.is_subagent());
4816 });
4817
4818 let depth_2_context = SubagentContext {
4819 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4820 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4821 depth: 2,
4822 summary_prompt: "Summarize depth 2".to_string(),
4823 context_low_prompt: "Context low depth 2".to_string(),
4824 };
4825
4826 let depth_2_subagent = cx.new(|cx| {
4827 Thread::new_subagent(
4828 project.clone(),
4829 project_context.clone(),
4830 context_server_registry.clone(),
4831 Templates::new(),
4832 model.clone(),
4833 depth_2_context,
4834 std::collections::BTreeMap::new(),
4835 cx,
4836 )
4837 });
4838
4839 depth_2_subagent.read_with(cx, |thread, _| {
4840 assert_eq!(thread.depth(), 2);
4841 assert!(thread.is_subagent());
4842 });
4843
4844 depth_2_subagent
4845 .update(cx, |thread, cx| {
4846 thread.submit_user_message("Nested task", cx)
4847 })
4848 .unwrap();
4849 cx.run_until_parked();
4850
4851 let pending = model.as_fake().pending_completions();
4852 assert!(
4853 !pending.is_empty(),
4854 "depth-2 subagent should be able to submit messages"
4855 );
4856}
4857
4858#[gpui::test]
4859async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4860 init_test(cx);
4861 always_allow_tools(cx);
4862
4863 cx.update(|cx| {
4864 cx.update_flags(true, vec!["subagents".to_string()]);
4865 });
4866
4867 let fs = FakeFs::new(cx.executor());
4868 fs.insert_tree(path!("/test"), json!({})).await;
4869 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4870 let project_context = cx.new(|_cx| ProjectContext::default());
4871 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4872 let context_server_registry =
4873 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4874 let model = Arc::new(FakeLanguageModel::default());
4875 let fake_model = model.as_fake();
4876
4877 let subagent_context = SubagentContext {
4878 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4879 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4880 depth: 1,
4881 summary_prompt: "Summarize what you did".to_string(),
4882 context_low_prompt: "Context low".to_string(),
4883 };
4884
4885 let subagent = cx.new(|cx| {
4886 let mut thread = Thread::new_subagent(
4887 project.clone(),
4888 project_context,
4889 context_server_registry,
4890 Templates::new(),
4891 model.clone(),
4892 subagent_context,
4893 std::collections::BTreeMap::new(),
4894 cx,
4895 );
4896 thread.add_tool(EchoTool);
4897 thread
4898 });
4899
4900 subagent.read_with(cx, |thread, _| {
4901 assert!(
4902 thread.has_registered_tool("echo"),
4903 "subagent should have echo tool"
4904 );
4905 });
4906
4907 subagent
4908 .update(cx, |thread, cx| {
4909 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4910 })
4911 .unwrap();
4912 cx.run_until_parked();
4913
4914 let tool_use = LanguageModelToolUse {
4915 id: "tool_call_1".into(),
4916 name: EchoTool::name().into(),
4917 raw_input: json!({"text": "hello world"}).to_string(),
4918 input: json!({"text": "hello world"}),
4919 is_input_complete: true,
4920 thought_signature: None,
4921 };
4922 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4923 fake_model.end_last_completion_stream();
4924 cx.run_until_parked();
4925
4926 let pending = fake_model.pending_completions();
4927 assert!(
4928 !pending.is_empty(),
4929 "should have pending completion after tool use"
4930 );
4931
4932 let last_completion = pending.last().unwrap();
4933 let has_tool_result = last_completion.messages.iter().any(|m| {
4934 m.content
4935 .iter()
4936 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4937 });
4938 assert!(
4939 has_tool_result,
4940 "tool result should be in the messages sent back to the model"
4941 );
4942}
4943
4944#[gpui::test]
4945async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4946 init_test(cx);
4947
4948 cx.update(|cx| {
4949 cx.update_flags(true, vec!["subagents".to_string()]);
4950 });
4951
4952 let fs = FakeFs::new(cx.executor());
4953 fs.insert_tree(path!("/test"), json!({})).await;
4954 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4955 let project_context = cx.new(|_cx| ProjectContext::default());
4956 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4957 let context_server_registry =
4958 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4959 let model = Arc::new(FakeLanguageModel::default());
4960
4961 let parent = cx.new(|cx| {
4962 Thread::new(
4963 project.clone(),
4964 project_context.clone(),
4965 context_server_registry.clone(),
4966 Templates::new(),
4967 Some(model.clone()),
4968 cx,
4969 )
4970 });
4971
4972 let mut subagents = Vec::new();
4973 for i in 0..MAX_PARALLEL_SUBAGENTS {
4974 let subagent_context = SubagentContext {
4975 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4976 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4977 depth: 1,
4978 summary_prompt: "Summarize".to_string(),
4979 context_low_prompt: "Context low".to_string(),
4980 };
4981
4982 let subagent = cx.new(|cx| {
4983 Thread::new_subagent(
4984 project.clone(),
4985 project_context.clone(),
4986 context_server_registry.clone(),
4987 Templates::new(),
4988 model.clone(),
4989 subagent_context,
4990 std::collections::BTreeMap::new(),
4991 cx,
4992 )
4993 });
4994
4995 parent.update(cx, |thread, _cx| {
4996 thread.register_running_subagent(subagent.downgrade());
4997 });
4998 subagents.push(subagent);
4999 }
5000
5001 parent.read_with(cx, |thread, _| {
5002 assert_eq!(
5003 thread.running_subagent_count(),
5004 MAX_PARALLEL_SUBAGENTS,
5005 "should have MAX_PARALLEL_SUBAGENTS registered"
5006 );
5007 });
5008
5009 let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
5010 std::collections::BTreeMap::new();
5011
5012 #[allow(clippy::arc_with_non_send_sync)]
5013 let tool = Arc::new(SubagentTool::new(
5014 parent.downgrade(),
5015 project.clone(),
5016 project_context,
5017 context_server_registry,
5018 Templates::new(),
5019 0,
5020 parent_tools,
5021 ));
5022
5023 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5024
5025 let result = cx.update(|cx| {
5026 tool.run(
5027 SubagentToolInput {
5028 subagents: vec![crate::SubagentConfig {
5029 label: "Test".to_string(),
5030 task_prompt: "Do something".to_string(),
5031 summary_prompt: "Summarize".to_string(),
5032 context_low_prompt: "Context low".to_string(),
5033 timeout_ms: None,
5034 allowed_tools: None,
5035 }],
5036 },
5037 event_stream,
5038 cx,
5039 )
5040 });
5041
5042 let err = result.await.unwrap_err();
5043 assert!(
5044 err.to_string().contains("Maximum parallel subagents"),
5045 "should reject when max parallel subagents reached: {}",
5046 err
5047 );
5048
5049 drop(subagents);
5050}
5051
5052#[gpui::test]
5053async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
5054 init_test(cx);
5055 always_allow_tools(cx);
5056
5057 cx.update(|cx| {
5058 cx.update_flags(true, vec!["subagents".to_string()]);
5059 });
5060
5061 let fs = FakeFs::new(cx.executor());
5062 fs.insert_tree(path!("/test"), json!({})).await;
5063 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5064 let project_context = cx.new(|_cx| ProjectContext::default());
5065 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5066 let context_server_registry =
5067 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5068 let model = Arc::new(FakeLanguageModel::default());
5069 let fake_model = model.as_fake();
5070
5071 let parent = cx.new(|cx| {
5072 let mut thread = Thread::new(
5073 project.clone(),
5074 project_context.clone(),
5075 context_server_registry.clone(),
5076 Templates::new(),
5077 Some(model.clone()),
5078 cx,
5079 );
5080 thread.add_tool(EchoTool);
5081 thread
5082 });
5083
5084 let mut parent_tools: std::collections::BTreeMap<
5085 gpui::SharedString,
5086 Arc<dyn crate::AnyAgentTool>,
5087 > = std::collections::BTreeMap::new();
5088 parent_tools.insert("echo".into(), EchoTool.erase());
5089
5090 #[allow(clippy::arc_with_non_send_sync)]
5091 let tool = Arc::new(SubagentTool::new(
5092 parent.downgrade(),
5093 project.clone(),
5094 project_context,
5095 context_server_registry,
5096 Templates::new(),
5097 0,
5098 parent_tools,
5099 ));
5100
5101 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5102
5103 let task = cx.update(|cx| {
5104 tool.run(
5105 SubagentToolInput {
5106 subagents: vec![crate::SubagentConfig {
5107 label: "Research task".to_string(),
5108 task_prompt: "Find all TODOs in the codebase".to_string(),
5109 summary_prompt: "Summarize what you found".to_string(),
5110 context_low_prompt: "Context low, wrap up".to_string(),
5111 timeout_ms: None,
5112 allowed_tools: None,
5113 }],
5114 },
5115 event_stream,
5116 cx,
5117 )
5118 });
5119
5120 cx.run_until_parked();
5121
5122 let pending = fake_model.pending_completions();
5123 assert!(
5124 !pending.is_empty(),
5125 "subagent should have started and sent a completion request"
5126 );
5127
5128 let first_completion = &pending[0];
5129 let has_task_prompt = first_completion.messages.iter().any(|m| {
5130 m.role == language_model::Role::User
5131 && m.content
5132 .iter()
5133 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5134 });
5135 assert!(has_task_prompt, "task prompt should be sent to subagent");
5136
5137 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5138 fake_model
5139 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5140 fake_model.end_last_completion_stream();
5141 cx.run_until_parked();
5142
5143 let pending = fake_model.pending_completions();
5144 assert!(
5145 !pending.is_empty(),
5146 "should have pending completion for summary request"
5147 );
5148
5149 let last_completion = pending.last().unwrap();
5150 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5151 m.role == language_model::Role::User
5152 && m.content.iter().any(|c| {
5153 c.to_str()
5154 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5155 .unwrap_or(false)
5156 })
5157 });
5158 assert!(
5159 has_summary_prompt,
5160 "summary prompt should be sent after task completion"
5161 );
5162
5163 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5164 fake_model
5165 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5166 fake_model.end_last_completion_stream();
5167 cx.run_until_parked();
5168
5169 let result = task.await;
5170 assert!(result.is_ok(), "subagent tool should complete successfully");
5171
5172 let summary = result.unwrap();
5173 assert!(
5174 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5175 "summary should contain subagent's response: {}",
5176 summary
5177 );
5178}
5179
5180#[gpui::test]
5181async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5182 init_test(cx);
5183
5184 let fs = FakeFs::new(cx.executor());
5185 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5186 .await;
5187 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5188
5189 cx.update(|cx| {
5190 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5191 settings.tool_permissions.tools.insert(
5192 "edit_file".into(),
5193 agent_settings::ToolRules {
5194 default_mode: settings::ToolPermissionMode::Allow,
5195 always_allow: vec![],
5196 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5197 always_confirm: vec![],
5198 invalid_patterns: vec![],
5199 },
5200 );
5201 agent_settings::AgentSettings::override_global(settings, cx);
5202 });
5203
5204 let context_server_registry =
5205 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5206 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5207 let templates = crate::Templates::new();
5208 let thread = cx.new(|cx| {
5209 crate::Thread::new(
5210 project.clone(),
5211 cx.new(|_cx| prompt_store::ProjectContext::default()),
5212 context_server_registry,
5213 templates.clone(),
5214 None,
5215 cx,
5216 )
5217 });
5218
5219 #[allow(clippy::arc_with_non_send_sync)]
5220 let tool = Arc::new(crate::EditFileTool::new(
5221 project.clone(),
5222 thread.downgrade(),
5223 language_registry,
5224 templates,
5225 ));
5226 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5227
5228 let task = cx.update(|cx| {
5229 tool.run(
5230 crate::EditFileToolInput {
5231 display_description: "Edit sensitive file".to_string(),
5232 path: "root/sensitive_config.txt".into(),
5233 mode: crate::EditFileMode::Edit,
5234 },
5235 event_stream,
5236 cx,
5237 )
5238 });
5239
5240 let result = task.await;
5241 assert!(result.is_err(), "expected edit to be blocked");
5242 assert!(
5243 result.unwrap_err().to_string().contains("blocked"),
5244 "error should mention the edit was blocked"
5245 );
5246}
5247
5248#[gpui::test]
5249async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5250 init_test(cx);
5251
5252 let fs = FakeFs::new(cx.executor());
5253 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5254 .await;
5255 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5256
5257 cx.update(|cx| {
5258 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5259 settings.tool_permissions.tools.insert(
5260 "delete_path".into(),
5261 agent_settings::ToolRules {
5262 default_mode: settings::ToolPermissionMode::Allow,
5263 always_allow: vec![],
5264 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5265 always_confirm: vec![],
5266 invalid_patterns: vec![],
5267 },
5268 );
5269 agent_settings::AgentSettings::override_global(settings, cx);
5270 });
5271
5272 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5273
5274 #[allow(clippy::arc_with_non_send_sync)]
5275 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5276 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5277
5278 let task = cx.update(|cx| {
5279 tool.run(
5280 crate::DeletePathToolInput {
5281 path: "root/important_data.txt".to_string(),
5282 },
5283 event_stream,
5284 cx,
5285 )
5286 });
5287
5288 let result = task.await;
5289 assert!(result.is_err(), "expected deletion to be blocked");
5290 assert!(
5291 result.unwrap_err().to_string().contains("blocked"),
5292 "error should mention the deletion was blocked"
5293 );
5294}
5295
5296#[gpui::test]
5297async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5298 init_test(cx);
5299
5300 let fs = FakeFs::new(cx.executor());
5301 fs.insert_tree(
5302 "/root",
5303 json!({
5304 "safe.txt": "content",
5305 "protected": {}
5306 }),
5307 )
5308 .await;
5309 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5310
5311 cx.update(|cx| {
5312 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5313 settings.tool_permissions.tools.insert(
5314 "move_path".into(),
5315 agent_settings::ToolRules {
5316 default_mode: settings::ToolPermissionMode::Allow,
5317 always_allow: vec![],
5318 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5319 always_confirm: vec![],
5320 invalid_patterns: vec![],
5321 },
5322 );
5323 agent_settings::AgentSettings::override_global(settings, cx);
5324 });
5325
5326 #[allow(clippy::arc_with_non_send_sync)]
5327 let tool = Arc::new(crate::MovePathTool::new(project));
5328 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5329
5330 let task = cx.update(|cx| {
5331 tool.run(
5332 crate::MovePathToolInput {
5333 source_path: "root/safe.txt".to_string(),
5334 destination_path: "root/protected/safe.txt".to_string(),
5335 },
5336 event_stream,
5337 cx,
5338 )
5339 });
5340
5341 let result = task.await;
5342 assert!(
5343 result.is_err(),
5344 "expected move to be blocked due to destination path"
5345 );
5346 assert!(
5347 result.unwrap_err().to_string().contains("blocked"),
5348 "error should mention the move was blocked"
5349 );
5350}
5351
5352#[gpui::test]
5353async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5354 init_test(cx);
5355
5356 let fs = FakeFs::new(cx.executor());
5357 fs.insert_tree(
5358 "/root",
5359 json!({
5360 "secret.txt": "secret content",
5361 "public": {}
5362 }),
5363 )
5364 .await;
5365 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5366
5367 cx.update(|cx| {
5368 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5369 settings.tool_permissions.tools.insert(
5370 "move_path".into(),
5371 agent_settings::ToolRules {
5372 default_mode: settings::ToolPermissionMode::Allow,
5373 always_allow: vec![],
5374 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5375 always_confirm: vec![],
5376 invalid_patterns: vec![],
5377 },
5378 );
5379 agent_settings::AgentSettings::override_global(settings, cx);
5380 });
5381
5382 #[allow(clippy::arc_with_non_send_sync)]
5383 let tool = Arc::new(crate::MovePathTool::new(project));
5384 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5385
5386 let task = cx.update(|cx| {
5387 tool.run(
5388 crate::MovePathToolInput {
5389 source_path: "root/secret.txt".to_string(),
5390 destination_path: "root/public/not_secret.txt".to_string(),
5391 },
5392 event_stream,
5393 cx,
5394 )
5395 });
5396
5397 let result = task.await;
5398 assert!(
5399 result.is_err(),
5400 "expected move to be blocked due to source path"
5401 );
5402 assert!(
5403 result.unwrap_err().to_string().contains("blocked"),
5404 "error should mention the move was blocked"
5405 );
5406}
5407
5408#[gpui::test]
5409async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5410 init_test(cx);
5411
5412 let fs = FakeFs::new(cx.executor());
5413 fs.insert_tree(
5414 "/root",
5415 json!({
5416 "confidential.txt": "confidential data",
5417 "dest": {}
5418 }),
5419 )
5420 .await;
5421 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5422
5423 cx.update(|cx| {
5424 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5425 settings.tool_permissions.tools.insert(
5426 "copy_path".into(),
5427 agent_settings::ToolRules {
5428 default_mode: settings::ToolPermissionMode::Allow,
5429 always_allow: vec![],
5430 always_deny: vec![
5431 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5432 ],
5433 always_confirm: vec![],
5434 invalid_patterns: vec![],
5435 },
5436 );
5437 agent_settings::AgentSettings::override_global(settings, cx);
5438 });
5439
5440 #[allow(clippy::arc_with_non_send_sync)]
5441 let tool = Arc::new(crate::CopyPathTool::new(project));
5442 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5443
5444 let task = cx.update(|cx| {
5445 tool.run(
5446 crate::CopyPathToolInput {
5447 source_path: "root/confidential.txt".to_string(),
5448 destination_path: "root/dest/copy.txt".to_string(),
5449 },
5450 event_stream,
5451 cx,
5452 )
5453 });
5454
5455 let result = task.await;
5456 assert!(result.is_err(), "expected copy to be blocked");
5457 assert!(
5458 result.unwrap_err().to_string().contains("blocked"),
5459 "error should mention the copy was blocked"
5460 );
5461}
5462
5463#[gpui::test]
5464async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5465 init_test(cx);
5466
5467 let fs = FakeFs::new(cx.executor());
5468 fs.insert_tree(
5469 "/root",
5470 json!({
5471 "normal.txt": "normal content",
5472 "readonly": {
5473 "config.txt": "readonly content"
5474 }
5475 }),
5476 )
5477 .await;
5478 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5479
5480 cx.update(|cx| {
5481 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5482 settings.tool_permissions.tools.insert(
5483 "save_file".into(),
5484 agent_settings::ToolRules {
5485 default_mode: settings::ToolPermissionMode::Allow,
5486 always_allow: vec![],
5487 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5488 always_confirm: vec![],
5489 invalid_patterns: vec![],
5490 },
5491 );
5492 agent_settings::AgentSettings::override_global(settings, cx);
5493 });
5494
5495 #[allow(clippy::arc_with_non_send_sync)]
5496 let tool = Arc::new(crate::SaveFileTool::new(project));
5497 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5498
5499 let task = cx.update(|cx| {
5500 tool.run(
5501 crate::SaveFileToolInput {
5502 paths: vec![
5503 std::path::PathBuf::from("root/normal.txt"),
5504 std::path::PathBuf::from("root/readonly/config.txt"),
5505 ],
5506 },
5507 event_stream,
5508 cx,
5509 )
5510 });
5511
5512 let result = task.await;
5513 assert!(
5514 result.is_err(),
5515 "expected save to be blocked due to denied path"
5516 );
5517 assert!(
5518 result.unwrap_err().to_string().contains("blocked"),
5519 "error should mention the save was blocked"
5520 );
5521}
5522
5523#[gpui::test]
5524async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5525 init_test(cx);
5526
5527 let fs = FakeFs::new(cx.executor());
5528 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5529 .await;
5530 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5531
5532 cx.update(|cx| {
5533 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5534 settings.always_allow_tool_actions = false;
5535 settings.tool_permissions.tools.insert(
5536 "save_file".into(),
5537 agent_settings::ToolRules {
5538 default_mode: settings::ToolPermissionMode::Allow,
5539 always_allow: vec![],
5540 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5541 always_confirm: vec![],
5542 invalid_patterns: vec![],
5543 },
5544 );
5545 agent_settings::AgentSettings::override_global(settings, cx);
5546 });
5547
5548 #[allow(clippy::arc_with_non_send_sync)]
5549 let tool = Arc::new(crate::SaveFileTool::new(project));
5550 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5551
5552 let task = cx.update(|cx| {
5553 tool.run(
5554 crate::SaveFileToolInput {
5555 paths: vec![std::path::PathBuf::from("root/config.secret")],
5556 },
5557 event_stream,
5558 cx,
5559 )
5560 });
5561
5562 let result = task.await;
5563 assert!(result.is_err(), "expected save to be blocked");
5564 assert!(
5565 result.unwrap_err().to_string().contains("blocked"),
5566 "error should mention the save was blocked"
5567 );
5568}
5569
5570#[gpui::test]
5571async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5572 init_test(cx);
5573
5574 cx.update(|cx| {
5575 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5576 settings.tool_permissions.tools.insert(
5577 "web_search".into(),
5578 agent_settings::ToolRules {
5579 default_mode: settings::ToolPermissionMode::Allow,
5580 always_allow: vec![],
5581 always_deny: vec![
5582 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5583 ],
5584 always_confirm: vec![],
5585 invalid_patterns: vec![],
5586 },
5587 );
5588 agent_settings::AgentSettings::override_global(settings, cx);
5589 });
5590
5591 #[allow(clippy::arc_with_non_send_sync)]
5592 let tool = Arc::new(crate::WebSearchTool);
5593 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5594
5595 let input: crate::WebSearchToolInput =
5596 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5597
5598 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5599
5600 let result = task.await;
5601 assert!(result.is_err(), "expected search to be blocked");
5602 assert!(
5603 result.unwrap_err().to_string().contains("blocked"),
5604 "error should mention the search was blocked"
5605 );
5606}
5607
5608#[gpui::test]
5609async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5610 init_test(cx);
5611
5612 let fs = FakeFs::new(cx.executor());
5613 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5614 .await;
5615 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5616
5617 cx.update(|cx| {
5618 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5619 settings.always_allow_tool_actions = false;
5620 settings.tool_permissions.tools.insert(
5621 "edit_file".into(),
5622 agent_settings::ToolRules {
5623 default_mode: settings::ToolPermissionMode::Confirm,
5624 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5625 always_deny: vec![],
5626 always_confirm: vec![],
5627 invalid_patterns: vec![],
5628 },
5629 );
5630 agent_settings::AgentSettings::override_global(settings, cx);
5631 });
5632
5633 let context_server_registry =
5634 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5635 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5636 let templates = crate::Templates::new();
5637 let thread = cx.new(|cx| {
5638 crate::Thread::new(
5639 project.clone(),
5640 cx.new(|_cx| prompt_store::ProjectContext::default()),
5641 context_server_registry,
5642 templates.clone(),
5643 None,
5644 cx,
5645 )
5646 });
5647
5648 #[allow(clippy::arc_with_non_send_sync)]
5649 let tool = Arc::new(crate::EditFileTool::new(
5650 project,
5651 thread.downgrade(),
5652 language_registry,
5653 templates,
5654 ));
5655 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5656
5657 let _task = cx.update(|cx| {
5658 tool.run(
5659 crate::EditFileToolInput {
5660 display_description: "Edit README".to_string(),
5661 path: "root/README.md".into(),
5662 mode: crate::EditFileMode::Edit,
5663 },
5664 event_stream,
5665 cx,
5666 )
5667 });
5668
5669 cx.run_until_parked();
5670
5671 let event = rx.try_next();
5672 assert!(
5673 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5674 "expected no authorization request for allowed .md file"
5675 );
5676}
5677
5678#[gpui::test]
5679async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5680 init_test(cx);
5681
5682 cx.update(|cx| {
5683 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5684 settings.tool_permissions.tools.insert(
5685 "fetch".into(),
5686 agent_settings::ToolRules {
5687 default_mode: settings::ToolPermissionMode::Allow,
5688 always_allow: vec![],
5689 always_deny: vec![
5690 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5691 ],
5692 always_confirm: vec![],
5693 invalid_patterns: vec![],
5694 },
5695 );
5696 agent_settings::AgentSettings::override_global(settings, cx);
5697 });
5698
5699 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5700
5701 #[allow(clippy::arc_with_non_send_sync)]
5702 let tool = Arc::new(crate::FetchTool::new(http_client));
5703 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5704
5705 let input: crate::FetchToolInput =
5706 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5707
5708 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5709
5710 let result = task.await;
5711 assert!(result.is_err(), "expected fetch to be blocked");
5712 assert!(
5713 result.unwrap_err().to_string().contains("blocked"),
5714 "error should mention the fetch was blocked"
5715 );
5716}
5717
5718#[gpui::test]
5719async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5720 init_test(cx);
5721
5722 cx.update(|cx| {
5723 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5724 settings.always_allow_tool_actions = false;
5725 settings.tool_permissions.tools.insert(
5726 "fetch".into(),
5727 agent_settings::ToolRules {
5728 default_mode: settings::ToolPermissionMode::Confirm,
5729 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5730 always_deny: vec![],
5731 always_confirm: vec![],
5732 invalid_patterns: vec![],
5733 },
5734 );
5735 agent_settings::AgentSettings::override_global(settings, cx);
5736 });
5737
5738 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5739
5740 #[allow(clippy::arc_with_non_send_sync)]
5741 let tool = Arc::new(crate::FetchTool::new(http_client));
5742 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5743
5744 let input: crate::FetchToolInput =
5745 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5746
5747 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5748
5749 cx.run_until_parked();
5750
5751 let event = rx.try_next();
5752 assert!(
5753 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5754 "expected no authorization request for allowed docs.rs URL"
5755 );
5756}