1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
66struct FakeTerminalHandle {
67 killed: Arc<AtomicBool>,
68 stopped_by_user: Arc<AtomicBool>,
69 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
70 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
71 output: acp::TerminalOutputResponse,
72 id: acp::TerminalId,
73}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeThreadEnvironment {
159 handle: Rc<FakeTerminalHandle>,
160}
161
162impl crate::ThreadEnvironment for FakeThreadEnvironment {
163 fn create_terminal(
164 &self,
165 _command: String,
166 _cwd: Option<std::path::PathBuf>,
167 _output_byte_limit: Option<u64>,
168 _cx: &mut AsyncApp,
169 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
170 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
171 }
172}
173
174/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
175struct MultiTerminalEnvironment {
176 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
177}
178
179impl MultiTerminalEnvironment {
180 fn new() -> Self {
181 Self {
182 handles: std::cell::RefCell::new(Vec::new()),
183 }
184 }
185
186 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
187 self.handles.borrow().clone()
188 }
189}
190
191impl crate::ThreadEnvironment for MultiTerminalEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
200 self.handles.borrow_mut().push(handle.clone());
201 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
202 }
203}
204
205fn always_allow_tools(cx: &mut TestAppContext) {
206 cx.update(|cx| {
207 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
208 settings.always_allow_tool_actions = true;
209 agent_settings::AgentSettings::override_global(settings, cx);
210 });
211}
212
213#[gpui::test]
214async fn test_echo(cx: &mut TestAppContext) {
215 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
216 let fake_model = model.as_fake();
217
218 let events = thread
219 .update(cx, |thread, cx| {
220 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
221 })
222 .unwrap();
223 cx.run_until_parked();
224 fake_model.send_last_completion_stream_text_chunk("Hello");
225 fake_model
226 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
227 fake_model.end_last_completion_stream();
228
229 let events = events.collect().await;
230 thread.update(cx, |thread, _cx| {
231 assert_eq!(
232 thread.last_message().unwrap().to_markdown(),
233 indoc! {"
234 ## Assistant
235
236 Hello
237 "}
238 )
239 });
240 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
241}
242
243#[gpui::test]
244async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
245 init_test(cx);
246 always_allow_tools(cx);
247
248 let fs = FakeFs::new(cx.executor());
249 let project = Project::test(fs, [], cx).await;
250
251 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
252 let environment = Rc::new(FakeThreadEnvironment {
253 handle: handle.clone(),
254 });
255
256 #[allow(clippy::arc_with_non_send_sync)]
257 let tool = Arc::new(crate::TerminalTool::new(project, environment));
258 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
259
260 let task = cx.update(|cx| {
261 tool.run(
262 crate::TerminalToolInput {
263 command: "sleep 1000".to_string(),
264 cd: ".".to_string(),
265 timeout_ms: Some(5),
266 },
267 event_stream,
268 cx,
269 )
270 });
271
272 let update = rx.expect_update_fields().await;
273 assert!(
274 update.content.iter().any(|blocks| {
275 blocks
276 .iter()
277 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
278 }),
279 "expected tool call update to include terminal content"
280 );
281
282 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
283
284 let deadline = std::time::Instant::now() + Duration::from_millis(500);
285 loop {
286 if let Some(result) = task_future.as_mut().now_or_never() {
287 let result = result.expect("terminal tool task should complete");
288
289 assert!(
290 handle.was_killed(),
291 "expected terminal handle to be killed on timeout"
292 );
293 assert!(
294 result.contains("partial output"),
295 "expected result to include terminal output, got: {result}"
296 );
297 return;
298 }
299
300 if std::time::Instant::now() >= deadline {
301 panic!("timed out waiting for terminal tool task to complete");
302 }
303
304 cx.run_until_parked();
305 cx.background_executor.timer(Duration::from_millis(1)).await;
306 }
307}
308
309#[gpui::test]
310#[ignore]
311async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
312 init_test(cx);
313 always_allow_tools(cx);
314
315 let fs = FakeFs::new(cx.executor());
316 let project = Project::test(fs, [], cx).await;
317
318 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
319 let environment = Rc::new(FakeThreadEnvironment {
320 handle: handle.clone(),
321 });
322
323 #[allow(clippy::arc_with_non_send_sync)]
324 let tool = Arc::new(crate::TerminalTool::new(project, environment));
325 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
326
327 let _task = cx.update(|cx| {
328 tool.run(
329 crate::TerminalToolInput {
330 command: "sleep 1000".to_string(),
331 cd: ".".to_string(),
332 timeout_ms: None,
333 },
334 event_stream,
335 cx,
336 )
337 });
338
339 let update = rx.expect_update_fields().await;
340 assert!(
341 update.content.iter().any(|blocks| {
342 blocks
343 .iter()
344 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
345 }),
346 "expected tool call update to include terminal content"
347 );
348
349 cx.background_executor
350 .timer(Duration::from_millis(25))
351 .await;
352
353 assert!(
354 !handle.was_killed(),
355 "did not expect terminal handle to be killed without a timeout"
356 );
357}
358
359#[gpui::test]
360async fn test_thinking(cx: &mut TestAppContext) {
361 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
362 let fake_model = model.as_fake();
363
364 let events = thread
365 .update(cx, |thread, cx| {
366 thread.send(
367 UserMessageId::new(),
368 [indoc! {"
369 Testing:
370
371 Generate a thinking step where you just think the word 'Think',
372 and have your final answer be 'Hello'
373 "}],
374 cx,
375 )
376 })
377 .unwrap();
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
380 text: "Think".to_string(),
381 signature: None,
382 });
383 fake_model.send_last_completion_stream_text_chunk("Hello");
384 fake_model
385 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
386 fake_model.end_last_completion_stream();
387
388 let events = events.collect().await;
389 thread.update(cx, |thread, _cx| {
390 assert_eq!(
391 thread.last_message().unwrap().to_markdown(),
392 indoc! {"
393 ## Assistant
394
395 <think>Think</think>
396 Hello
397 "}
398 )
399 });
400 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
401}
402
403#[gpui::test]
404async fn test_system_prompt(cx: &mut TestAppContext) {
405 let ThreadTest {
406 model,
407 thread,
408 project_context,
409 ..
410 } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 project_context.update(cx, |project_context, _cx| {
414 project_context.shell = "test-shell".into()
415 });
416 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
417 thread
418 .update(cx, |thread, cx| {
419 thread.send(UserMessageId::new(), ["abc"], cx)
420 })
421 .unwrap();
422 cx.run_until_parked();
423 let mut pending_completions = fake_model.pending_completions();
424 assert_eq!(
425 pending_completions.len(),
426 1,
427 "unexpected pending completions: {:?}",
428 pending_completions
429 );
430
431 let pending_completion = pending_completions.pop().unwrap();
432 assert_eq!(pending_completion.messages[0].role, Role::System);
433
434 let system_message = &pending_completion.messages[0];
435 let system_prompt = system_message.content[0].to_str().unwrap();
436 assert!(
437 system_prompt.contains("test-shell"),
438 "unexpected system message: {:?}",
439 system_message
440 );
441 assert!(
442 system_prompt.contains("## Fixing Diagnostics"),
443 "unexpected system message: {:?}",
444 system_message
445 );
446}
447
448#[gpui::test]
449async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
450 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
451 let fake_model = model.as_fake();
452
453 thread
454 .update(cx, |thread, cx| {
455 thread.send(UserMessageId::new(), ["abc"], cx)
456 })
457 .unwrap();
458 cx.run_until_parked();
459 let mut pending_completions = fake_model.pending_completions();
460 assert_eq!(
461 pending_completions.len(),
462 1,
463 "unexpected pending completions: {:?}",
464 pending_completions
465 );
466
467 let pending_completion = pending_completions.pop().unwrap();
468 assert_eq!(pending_completion.messages[0].role, Role::System);
469
470 let system_message = &pending_completion.messages[0];
471 let system_prompt = system_message.content[0].to_str().unwrap();
472 assert!(
473 !system_prompt.contains("## Tool Use"),
474 "unexpected system message: {:?}",
475 system_message
476 );
477 assert!(
478 !system_prompt.contains("## Fixing Diagnostics"),
479 "unexpected system message: {:?}",
480 system_message
481 );
482}
483
484#[gpui::test]
485async fn test_prompt_caching(cx: &mut TestAppContext) {
486 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
487 let fake_model = model.as_fake();
488
489 // Send initial user message and verify it's cached
490 thread
491 .update(cx, |thread, cx| {
492 thread.send(UserMessageId::new(), ["Message 1"], cx)
493 })
494 .unwrap();
495 cx.run_until_parked();
496
497 let completion = fake_model.pending_completions().pop().unwrap();
498 assert_eq!(
499 completion.messages[1..],
500 vec![LanguageModelRequestMessage {
501 role: Role::User,
502 content: vec!["Message 1".into()],
503 cache: true,
504 reasoning_details: None,
505 }]
506 );
507 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
508 "Response to Message 1".into(),
509 ));
510 fake_model.end_last_completion_stream();
511 cx.run_until_parked();
512
513 // Send another user message and verify only the latest is cached
514 thread
515 .update(cx, |thread, cx| {
516 thread.send(UserMessageId::new(), ["Message 2"], cx)
517 })
518 .unwrap();
519 cx.run_until_parked();
520
521 let completion = fake_model.pending_completions().pop().unwrap();
522 assert_eq!(
523 completion.messages[1..],
524 vec![
525 LanguageModelRequestMessage {
526 role: Role::User,
527 content: vec!["Message 1".into()],
528 cache: false,
529 reasoning_details: None,
530 },
531 LanguageModelRequestMessage {
532 role: Role::Assistant,
533 content: vec!["Response to Message 1".into()],
534 cache: false,
535 reasoning_details: None,
536 },
537 LanguageModelRequestMessage {
538 role: Role::User,
539 content: vec!["Message 2".into()],
540 cache: true,
541 reasoning_details: None,
542 }
543 ]
544 );
545 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
546 "Response to Message 2".into(),
547 ));
548 fake_model.end_last_completion_stream();
549 cx.run_until_parked();
550
551 // Simulate a tool call and verify that the latest tool result is cached
552 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
553 thread
554 .update(cx, |thread, cx| {
555 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
556 })
557 .unwrap();
558 cx.run_until_parked();
559
560 let tool_use = LanguageModelToolUse {
561 id: "tool_1".into(),
562 name: EchoTool::NAME.into(),
563 raw_input: json!({"text": "test"}).to_string(),
564 input: json!({"text": "test"}),
565 is_input_complete: true,
566 thought_signature: None,
567 };
568 fake_model
569 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
570 fake_model.end_last_completion_stream();
571 cx.run_until_parked();
572
573 let completion = fake_model.pending_completions().pop().unwrap();
574 let tool_result = LanguageModelToolResult {
575 tool_use_id: "tool_1".into(),
576 tool_name: EchoTool::NAME.into(),
577 is_error: false,
578 content: "test".into(),
579 output: Some("test".into()),
580 };
581 assert_eq!(
582 completion.messages[1..],
583 vec![
584 LanguageModelRequestMessage {
585 role: Role::User,
586 content: vec!["Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::Assistant,
592 content: vec!["Response to Message 1".into()],
593 cache: false,
594 reasoning_details: None,
595 },
596 LanguageModelRequestMessage {
597 role: Role::User,
598 content: vec!["Message 2".into()],
599 cache: false,
600 reasoning_details: None,
601 },
602 LanguageModelRequestMessage {
603 role: Role::Assistant,
604 content: vec!["Response to Message 2".into()],
605 cache: false,
606 reasoning_details: None,
607 },
608 LanguageModelRequestMessage {
609 role: Role::User,
610 content: vec!["Use the echo tool".into()],
611 cache: false,
612 reasoning_details: None,
613 },
614 LanguageModelRequestMessage {
615 role: Role::Assistant,
616 content: vec![MessageContent::ToolUse(tool_use)],
617 cache: false,
618 reasoning_details: None,
619 },
620 LanguageModelRequestMessage {
621 role: Role::User,
622 content: vec![MessageContent::ToolResult(tool_result)],
623 cache: true,
624 reasoning_details: None,
625 }
626 ]
627 );
628}
629
630#[gpui::test]
631#[cfg_attr(not(feature = "e2e"), ignore)]
632async fn test_basic_tool_calls(cx: &mut TestAppContext) {
633 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
634
635 // Test a tool call that's likely to complete *before* streaming stops.
636 let events = thread
637 .update(cx, |thread, cx| {
638 thread.add_tool(EchoTool);
639 thread.send(
640 UserMessageId::new(),
641 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
642 cx,
643 )
644 })
645 .unwrap()
646 .collect()
647 .await;
648 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
649
650 // Test a tool calls that's likely to complete *after* streaming stops.
651 let events = thread
652 .update(cx, |thread, cx| {
653 thread.remove_tool(&EchoTool::NAME);
654 thread.add_tool(DelayTool);
655 thread.send(
656 UserMessageId::new(),
657 [
658 "Now call the delay tool with 200ms.",
659 "When the timer goes off, then you echo the output of the tool.",
660 ],
661 cx,
662 )
663 })
664 .unwrap()
665 .collect()
666 .await;
667 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
668 thread.update(cx, |thread, _cx| {
669 assert!(
670 thread
671 .last_message()
672 .unwrap()
673 .as_agent_message()
674 .unwrap()
675 .content
676 .iter()
677 .any(|content| {
678 if let AgentMessageContent::Text(text) = content {
679 text.contains("Ding")
680 } else {
681 false
682 }
683 }),
684 "{}",
685 thread.to_markdown()
686 );
687 });
688}
689
690#[gpui::test]
691#[cfg_attr(not(feature = "e2e"), ignore)]
692async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
693 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
694
695 // Test a tool call that's likely to complete *before* streaming stops.
696 let mut events = thread
697 .update(cx, |thread, cx| {
698 thread.add_tool(WordListTool);
699 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
700 })
701 .unwrap();
702
703 let mut saw_partial_tool_use = false;
704 while let Some(event) = events.next().await {
705 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
706 thread.update(cx, |thread, _cx| {
707 // Look for a tool use in the thread's last message
708 let message = thread.last_message().unwrap();
709 let agent_message = message.as_agent_message().unwrap();
710 let last_content = agent_message.content.last().unwrap();
711 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
712 assert_eq!(last_tool_use.name.as_ref(), "word_list");
713 if tool_call.status == acp::ToolCallStatus::Pending {
714 if !last_tool_use.is_input_complete
715 && last_tool_use.input.get("g").is_none()
716 {
717 saw_partial_tool_use = true;
718 }
719 } else {
720 last_tool_use
721 .input
722 .get("a")
723 .expect("'a' has streamed because input is now complete");
724 last_tool_use
725 .input
726 .get("g")
727 .expect("'g' has streamed because input is now complete");
728 }
729 } else {
730 panic!("last content should be a tool use");
731 }
732 });
733 }
734 }
735
736 assert!(
737 saw_partial_tool_use,
738 "should see at least one partially streamed tool use in the history"
739 );
740}
741
742#[gpui::test]
743async fn test_tool_authorization(cx: &mut TestAppContext) {
744 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
745 let fake_model = model.as_fake();
746
747 let mut events = thread
748 .update(cx, |thread, cx| {
749 thread.add_tool(ToolRequiringPermission);
750 thread.send(UserMessageId::new(), ["abc"], cx)
751 })
752 .unwrap();
753 cx.run_until_parked();
754 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
755 LanguageModelToolUse {
756 id: "tool_id_1".into(),
757 name: ToolRequiringPermission::NAME.into(),
758 raw_input: "{}".into(),
759 input: json!({}),
760 is_input_complete: true,
761 thought_signature: None,
762 },
763 ));
764 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
765 LanguageModelToolUse {
766 id: "tool_id_2".into(),
767 name: ToolRequiringPermission::NAME.into(),
768 raw_input: "{}".into(),
769 input: json!({}),
770 is_input_complete: true,
771 thought_signature: None,
772 },
773 ));
774 fake_model.end_last_completion_stream();
775 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
776 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
777
778 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
779 tool_call_auth_1
780 .response
781 .send(acp::PermissionOptionId::new("allow"))
782 .unwrap();
783 cx.run_until_parked();
784
785 // Reject the second - send "deny" option_id directly since Deny is now a button
786 tool_call_auth_2
787 .response
788 .send(acp::PermissionOptionId::new("deny"))
789 .unwrap();
790 cx.run_until_parked();
791
792 let completion = fake_model.pending_completions().pop().unwrap();
793 let message = completion.messages.last().unwrap();
794 assert_eq!(
795 message.content,
796 vec![
797 language_model::MessageContent::ToolResult(LanguageModelToolResult {
798 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
799 tool_name: ToolRequiringPermission::NAME.into(),
800 is_error: false,
801 content: "Allowed".into(),
802 output: Some("Allowed".into())
803 }),
804 language_model::MessageContent::ToolResult(LanguageModelToolResult {
805 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
806 tool_name: ToolRequiringPermission::NAME.into(),
807 is_error: true,
808 content: "Permission to run tool denied by user".into(),
809 output: Some("Permission to run tool denied by user".into())
810 })
811 ]
812 );
813
814 // Simulate yet another tool call.
815 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
816 LanguageModelToolUse {
817 id: "tool_id_3".into(),
818 name: ToolRequiringPermission::NAME.into(),
819 raw_input: "{}".into(),
820 input: json!({}),
821 is_input_complete: true,
822 thought_signature: None,
823 },
824 ));
825 fake_model.end_last_completion_stream();
826
827 // Respond by always allowing tools - send transformed option_id
828 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
829 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
830 tool_call_auth_3
831 .response
832 .send(acp::PermissionOptionId::new(
833 "always_allow:tool_requiring_permission",
834 ))
835 .unwrap();
836 cx.run_until_parked();
837 let completion = fake_model.pending_completions().pop().unwrap();
838 let message = completion.messages.last().unwrap();
839 assert_eq!(
840 message.content,
841 vec![language_model::MessageContent::ToolResult(
842 LanguageModelToolResult {
843 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
844 tool_name: ToolRequiringPermission::NAME.into(),
845 is_error: false,
846 content: "Allowed".into(),
847 output: Some("Allowed".into())
848 }
849 )]
850 );
851
852 // Simulate a final tool call, ensuring we don't trigger authorization.
853 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
854 LanguageModelToolUse {
855 id: "tool_id_4".into(),
856 name: ToolRequiringPermission::NAME.into(),
857 raw_input: "{}".into(),
858 input: json!({}),
859 is_input_complete: true,
860 thought_signature: None,
861 },
862 ));
863 fake_model.end_last_completion_stream();
864 cx.run_until_parked();
865 let completion = fake_model.pending_completions().pop().unwrap();
866 let message = completion.messages.last().unwrap();
867 assert_eq!(
868 message.content,
869 vec![language_model::MessageContent::ToolResult(
870 LanguageModelToolResult {
871 tool_use_id: "tool_id_4".into(),
872 tool_name: ToolRequiringPermission::NAME.into(),
873 is_error: false,
874 content: "Allowed".into(),
875 output: Some("Allowed".into())
876 }
877 )]
878 );
879}
880
881#[gpui::test]
882async fn test_tool_hallucination(cx: &mut TestAppContext) {
883 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
884 let fake_model = model.as_fake();
885
886 let mut events = thread
887 .update(cx, |thread, cx| {
888 thread.send(UserMessageId::new(), ["abc"], cx)
889 })
890 .unwrap();
891 cx.run_until_parked();
892 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
893 LanguageModelToolUse {
894 id: "tool_id_1".into(),
895 name: "nonexistent_tool".into(),
896 raw_input: "{}".into(),
897 input: json!({}),
898 is_input_complete: true,
899 thought_signature: None,
900 },
901 ));
902 fake_model.end_last_completion_stream();
903
904 let tool_call = expect_tool_call(&mut events).await;
905 assert_eq!(tool_call.title, "nonexistent_tool");
906 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
907 let update = expect_tool_call_update_fields(&mut events).await;
908 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
909}
910
911async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
912 let event = events
913 .next()
914 .await
915 .expect("no tool call authorization event received")
916 .unwrap();
917 match event {
918 ThreadEvent::ToolCall(tool_call) => tool_call,
919 event => {
920 panic!("Unexpected event {event:?}");
921 }
922 }
923}
924
925async fn expect_tool_call_update_fields(
926 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
927) -> acp::ToolCallUpdate {
928 let event = events
929 .next()
930 .await
931 .expect("no tool call authorization event received")
932 .unwrap();
933 match event {
934 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
935 event => {
936 panic!("Unexpected event {event:?}");
937 }
938 }
939}
940
941async fn next_tool_call_authorization(
942 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
943) -> ToolCallAuthorization {
944 loop {
945 let event = events
946 .next()
947 .await
948 .expect("no tool call authorization event received")
949 .unwrap();
950 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
951 let permission_kinds = tool_call_authorization
952 .options
953 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
954 .map(|option| option.kind);
955 let allow_once = tool_call_authorization
956 .options
957 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
958 .map(|option| option.kind);
959
960 assert_eq!(
961 permission_kinds,
962 Some(acp::PermissionOptionKind::AllowAlways)
963 );
964 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
965 return tool_call_authorization;
966 }
967 }
968}
969
/// A terminal command with a recognizable program name offers three choices:
/// always for the tool, always for that command, and only this time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "cargo build --release");

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let mut labels: Vec<&str> = Vec::with_capacity(choices.len());
    for choice in &choices {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
989
/// An edit-file request offers an "always" choice for the tool and one for the
/// file's parent directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, "src/main.rs");

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let mut labels: Vec<&str> = Vec::with_capacity(choices.len());
    for choice in &choices {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1006
/// A fetch request offers an "always" choice for the tool and one for the
/// URL's domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context = ToolPermissionContext::new(FetchTool::NAME, "https://docs.rs/gpui");

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let mut labels: Vec<&str> = Vec::with_capacity(choices.len());
    for choice in &choices {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for fetch"));
    assert!(labels.contains(&"Always for `docs.rs`"));
}
1023
/// A command with no extractable pattern offers only two choices, and none of
/// them is a per-command option.
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "./deploy.sh --production");

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 2);

    let mut labels: Vec<&str> = Vec::with_capacity(choices.len());
    for choice in &choices {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Only this time"));
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1043
/// Terminal permission choices expose the expected allow and deny option ids,
/// including the pattern-scoped variants.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "cargo build --release");

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // Gather the allow-side and deny-side ids of every choice in one pass.
    let mut allow_ids: Vec<String> = Vec::with_capacity(choices.len());
    let mut deny_ids: Vec<String> = Vec::with_capacity(choices.len());
    for choice in &choices {
        allow_ids.push(choice.allow.option_id.0.to_string());
        deny_ids.push(choice.deny.option_id.0.to_string());
    }

    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal:"),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal:"),
        "Missing deny pattern option"
    );
}
1081
1082#[gpui::test]
1083#[cfg_attr(not(feature = "e2e"), ignore)]
1084async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1085 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1086
1087 // Test concurrent tool calls with different delay times
1088 let events = thread
1089 .update(cx, |thread, cx| {
1090 thread.add_tool(DelayTool);
1091 thread.send(
1092 UserMessageId::new(),
1093 [
1094 "Call the delay tool twice in the same message.",
1095 "Once with 100ms. Once with 300ms.",
1096 "When both timers are complete, describe the outputs.",
1097 ],
1098 cx,
1099 )
1100 })
1101 .unwrap()
1102 .collect()
1103 .await;
1104
1105 let stop_reasons = stop_events(events);
1106 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1107
1108 thread.update(cx, |thread, _cx| {
1109 let last_message = thread.last_message().unwrap();
1110 let agent_message = last_message.as_agent_message().unwrap();
1111 let text = agent_message
1112 .content
1113 .iter()
1114 .filter_map(|content| {
1115 if let AgentMessageContent::Text(text) = content {
1116 Some(text.as_str())
1117 } else {
1118 None
1119 }
1120 })
1121 .collect::<String>();
1122
1123 assert!(text.contains("Ding"));
1124 });
1125}
1126
1127#[gpui::test]
1128async fn test_profiles(cx: &mut TestAppContext) {
1129 let ThreadTest {
1130 model, thread, fs, ..
1131 } = setup(cx, TestModel::Fake).await;
1132 let fake_model = model.as_fake();
1133
1134 thread.update(cx, |thread, _cx| {
1135 thread.add_tool(DelayTool);
1136 thread.add_tool(EchoTool);
1137 thread.add_tool(InfiniteTool);
1138 });
1139
1140 // Override profiles and wait for settings to be loaded.
1141 fs.insert_file(
1142 paths::settings_file(),
1143 json!({
1144 "agent": {
1145 "profiles": {
1146 "test-1": {
1147 "name": "Test Profile 1",
1148 "tools": {
1149 EchoTool::NAME: true,
1150 DelayTool::NAME: true,
1151 }
1152 },
1153 "test-2": {
1154 "name": "Test Profile 2",
1155 "tools": {
1156 InfiniteTool::NAME: true,
1157 }
1158 }
1159 }
1160 }
1161 })
1162 .to_string()
1163 .into_bytes(),
1164 )
1165 .await;
1166 cx.run_until_parked();
1167
1168 // Test that test-1 profile (default) has echo and delay tools
1169 thread
1170 .update(cx, |thread, cx| {
1171 thread.set_profile(AgentProfileId("test-1".into()), cx);
1172 thread.send(UserMessageId::new(), ["test"], cx)
1173 })
1174 .unwrap();
1175 cx.run_until_parked();
1176
1177 let mut pending_completions = fake_model.pending_completions();
1178 assert_eq!(pending_completions.len(), 1);
1179 let completion = pending_completions.pop().unwrap();
1180 let tool_names: Vec<String> = completion
1181 .tools
1182 .iter()
1183 .map(|tool| tool.name.clone())
1184 .collect();
1185 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1186 fake_model.end_last_completion_stream();
1187
1188 // Switch to test-2 profile, and verify that it has only the infinite tool.
1189 thread
1190 .update(cx, |thread, cx| {
1191 thread.set_profile(AgentProfileId("test-2".into()), cx);
1192 thread.send(UserMessageId::new(), ["test2"], cx)
1193 })
1194 .unwrap();
1195 cx.run_until_parked();
1196 let mut pending_completions = fake_model.pending_completions();
1197 assert_eq!(pending_completions.len(), 1);
1198 let completion = pending_completions.pop().unwrap();
1199 let tool_names: Vec<String> = completion
1200 .tools
1201 .iter()
1202 .map(|tool| tool.name.clone())
1203 .collect();
1204 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1205}
1206
1207#[gpui::test]
1208async fn test_mcp_tools(cx: &mut TestAppContext) {
1209 let ThreadTest {
1210 model,
1211 thread,
1212 context_server_store,
1213 fs,
1214 ..
1215 } = setup(cx, TestModel::Fake).await;
1216 let fake_model = model.as_fake();
1217
1218 // Override profiles and wait for settings to be loaded.
1219 fs.insert_file(
1220 paths::settings_file(),
1221 json!({
1222 "agent": {
1223 "always_allow_tool_actions": true,
1224 "profiles": {
1225 "test": {
1226 "name": "Test Profile",
1227 "enable_all_context_servers": true,
1228 "tools": {
1229 EchoTool::NAME: true,
1230 }
1231 },
1232 }
1233 }
1234 })
1235 .to_string()
1236 .into_bytes(),
1237 )
1238 .await;
1239 cx.run_until_parked();
1240 thread.update(cx, |thread, cx| {
1241 thread.set_profile(AgentProfileId("test".into()), cx)
1242 });
1243
1244 let mut mcp_tool_calls = setup_context_server(
1245 "test_server",
1246 vec![context_server::types::Tool {
1247 name: "echo".into(),
1248 description: None,
1249 input_schema: serde_json::to_value(EchoTool::input_schema(
1250 LanguageModelToolSchemaFormat::JsonSchema,
1251 ))
1252 .unwrap(),
1253 output_schema: None,
1254 annotations: None,
1255 }],
1256 &context_server_store,
1257 cx,
1258 );
1259
1260 let events = thread.update(cx, |thread, cx| {
1261 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1262 });
1263 cx.run_until_parked();
1264
1265 // Simulate the model calling the MCP tool.
1266 let completion = fake_model.pending_completions().pop().unwrap();
1267 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1268 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1269 LanguageModelToolUse {
1270 id: "tool_1".into(),
1271 name: "echo".into(),
1272 raw_input: json!({"text": "test"}).to_string(),
1273 input: json!({"text": "test"}),
1274 is_input_complete: true,
1275 thought_signature: None,
1276 },
1277 ));
1278 fake_model.end_last_completion_stream();
1279 cx.run_until_parked();
1280
1281 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1282 assert_eq!(tool_call_params.name, "echo");
1283 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1284 tool_call_response
1285 .send(context_server::types::CallToolResponse {
1286 content: vec![context_server::types::ToolResponseContent::Text {
1287 text: "test".into(),
1288 }],
1289 is_error: None,
1290 meta: None,
1291 structured_content: None,
1292 })
1293 .unwrap();
1294 cx.run_until_parked();
1295
1296 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1297 fake_model.send_last_completion_stream_text_chunk("Done!");
1298 fake_model.end_last_completion_stream();
1299 events.collect::<Vec<_>>().await;
1300
1301 // Send again after adding the echo tool, ensuring the name collision is resolved.
1302 let events = thread.update(cx, |thread, cx| {
1303 thread.add_tool(EchoTool);
1304 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1305 });
1306 cx.run_until_parked();
1307 let completion = fake_model.pending_completions().pop().unwrap();
1308 assert_eq!(
1309 tool_names_for_completion(&completion),
1310 vec!["echo", "test_server_echo"]
1311 );
1312 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1313 LanguageModelToolUse {
1314 id: "tool_2".into(),
1315 name: "test_server_echo".into(),
1316 raw_input: json!({"text": "mcp"}).to_string(),
1317 input: json!({"text": "mcp"}),
1318 is_input_complete: true,
1319 thought_signature: None,
1320 },
1321 ));
1322 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1323 LanguageModelToolUse {
1324 id: "tool_3".into(),
1325 name: "echo".into(),
1326 raw_input: json!({"text": "native"}).to_string(),
1327 input: json!({"text": "native"}),
1328 is_input_complete: true,
1329 thought_signature: None,
1330 },
1331 ));
1332 fake_model.end_last_completion_stream();
1333 cx.run_until_parked();
1334
1335 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1336 assert_eq!(tool_call_params.name, "echo");
1337 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1338 tool_call_response
1339 .send(context_server::types::CallToolResponse {
1340 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1341 is_error: None,
1342 meta: None,
1343 structured_content: None,
1344 })
1345 .unwrap();
1346 cx.run_until_parked();
1347
1348 // Ensure the tool results were inserted with the correct names.
1349 let completion = fake_model.pending_completions().pop().unwrap();
1350 assert_eq!(
1351 completion.messages.last().unwrap().content,
1352 vec![
1353 MessageContent::ToolResult(LanguageModelToolResult {
1354 tool_use_id: "tool_3".into(),
1355 tool_name: "echo".into(),
1356 is_error: false,
1357 content: "native".into(),
1358 output: Some("native".into()),
1359 },),
1360 MessageContent::ToolResult(LanguageModelToolResult {
1361 tool_use_id: "tool_2".into(),
1362 tool_name: "test_server_echo".into(),
1363 is_error: false,
1364 content: "mcp".into(),
1365 output: Some("mcp".into()),
1366 },),
1367 ]
1368 );
1369 fake_model.end_last_completion_stream();
1370 events.collect::<Vec<_>>().await;
1371}
1372
1373#[gpui::test]
1374async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1375 let ThreadTest {
1376 model,
1377 thread,
1378 context_server_store,
1379 fs,
1380 ..
1381 } = setup(cx, TestModel::Fake).await;
1382 let fake_model = model.as_fake();
1383
1384 // Set up a profile with all tools enabled
1385 fs.insert_file(
1386 paths::settings_file(),
1387 json!({
1388 "agent": {
1389 "profiles": {
1390 "test": {
1391 "name": "Test Profile",
1392 "enable_all_context_servers": true,
1393 "tools": {
1394 EchoTool::NAME: true,
1395 DelayTool::NAME: true,
1396 WordListTool::NAME: true,
1397 ToolRequiringPermission::NAME: true,
1398 InfiniteTool::NAME: true,
1399 }
1400 },
1401 }
1402 }
1403 })
1404 .to_string()
1405 .into_bytes(),
1406 )
1407 .await;
1408 cx.run_until_parked();
1409
1410 thread.update(cx, |thread, cx| {
1411 thread.set_profile(AgentProfileId("test".into()), cx);
1412 thread.add_tool(EchoTool);
1413 thread.add_tool(DelayTool);
1414 thread.add_tool(WordListTool);
1415 thread.add_tool(ToolRequiringPermission);
1416 thread.add_tool(InfiniteTool);
1417 });
1418
1419 // Set up multiple context servers with some overlapping tool names
1420 let _server1_calls = setup_context_server(
1421 "xxx",
1422 vec![
1423 context_server::types::Tool {
1424 name: "echo".into(), // Conflicts with native EchoTool
1425 description: None,
1426 input_schema: serde_json::to_value(EchoTool::input_schema(
1427 LanguageModelToolSchemaFormat::JsonSchema,
1428 ))
1429 .unwrap(),
1430 output_schema: None,
1431 annotations: None,
1432 },
1433 context_server::types::Tool {
1434 name: "unique_tool_1".into(),
1435 description: None,
1436 input_schema: json!({"type": "object", "properties": {}}),
1437 output_schema: None,
1438 annotations: None,
1439 },
1440 ],
1441 &context_server_store,
1442 cx,
1443 );
1444
1445 let _server2_calls = setup_context_server(
1446 "yyy",
1447 vec![
1448 context_server::types::Tool {
1449 name: "echo".into(), // Also conflicts with native EchoTool
1450 description: None,
1451 input_schema: serde_json::to_value(EchoTool::input_schema(
1452 LanguageModelToolSchemaFormat::JsonSchema,
1453 ))
1454 .unwrap(),
1455 output_schema: None,
1456 annotations: None,
1457 },
1458 context_server::types::Tool {
1459 name: "unique_tool_2".into(),
1460 description: None,
1461 input_schema: json!({"type": "object", "properties": {}}),
1462 output_schema: None,
1463 annotations: None,
1464 },
1465 context_server::types::Tool {
1466 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1467 description: None,
1468 input_schema: json!({"type": "object", "properties": {}}),
1469 output_schema: None,
1470 annotations: None,
1471 },
1472 context_server::types::Tool {
1473 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1474 description: None,
1475 input_schema: json!({"type": "object", "properties": {}}),
1476 output_schema: None,
1477 annotations: None,
1478 },
1479 ],
1480 &context_server_store,
1481 cx,
1482 );
1483 let _server3_calls = setup_context_server(
1484 "zzz",
1485 vec![
1486 context_server::types::Tool {
1487 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1488 description: None,
1489 input_schema: json!({"type": "object", "properties": {}}),
1490 output_schema: None,
1491 annotations: None,
1492 },
1493 context_server::types::Tool {
1494 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1495 description: None,
1496 input_schema: json!({"type": "object", "properties": {}}),
1497 output_schema: None,
1498 annotations: None,
1499 },
1500 context_server::types::Tool {
1501 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1502 description: None,
1503 input_schema: json!({"type": "object", "properties": {}}),
1504 output_schema: None,
1505 annotations: None,
1506 },
1507 ],
1508 &context_server_store,
1509 cx,
1510 );
1511
1512 thread
1513 .update(cx, |thread, cx| {
1514 thread.send(UserMessageId::new(), ["Go"], cx)
1515 })
1516 .unwrap();
1517 cx.run_until_parked();
1518 let completion = fake_model.pending_completions().pop().unwrap();
1519 assert_eq!(
1520 tool_names_for_completion(&completion),
1521 vec![
1522 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1523 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1524 "delay",
1525 "echo",
1526 "infinite",
1527 "tool_requiring_permission",
1528 "unique_tool_1",
1529 "unique_tool_2",
1530 "word_list",
1531 "xxx_echo",
1532 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1533 "yyy_echo",
1534 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1535 ]
1536 );
1537}
1538
1539#[gpui::test]
1540#[cfg_attr(not(feature = "e2e"), ignore)]
1541async fn test_cancellation(cx: &mut TestAppContext) {
1542 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1543
1544 let mut events = thread
1545 .update(cx, |thread, cx| {
1546 thread.add_tool(InfiniteTool);
1547 thread.add_tool(EchoTool);
1548 thread.send(
1549 UserMessageId::new(),
1550 ["Call the echo tool, then call the infinite tool, then explain their output"],
1551 cx,
1552 )
1553 })
1554 .unwrap();
1555
1556 // Wait until both tools are called.
1557 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1558 let mut echo_id = None;
1559 let mut echo_completed = false;
1560 while let Some(event) = events.next().await {
1561 match event.unwrap() {
1562 ThreadEvent::ToolCall(tool_call) => {
1563 assert_eq!(tool_call.title, expected_tools.remove(0));
1564 if tool_call.title == "Echo" {
1565 echo_id = Some(tool_call.tool_call_id);
1566 }
1567 }
1568 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1569 acp::ToolCallUpdate {
1570 tool_call_id,
1571 fields:
1572 acp::ToolCallUpdateFields {
1573 status: Some(acp::ToolCallStatus::Completed),
1574 ..
1575 },
1576 ..
1577 },
1578 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1579 echo_completed = true;
1580 }
1581 _ => {}
1582 }
1583
1584 if expected_tools.is_empty() && echo_completed {
1585 break;
1586 }
1587 }
1588
1589 // Cancel the current send and ensure that the event stream is closed, even
1590 // if one of the tools is still running.
1591 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1592 let events = events.collect::<Vec<_>>().await;
1593 let last_event = events.last();
1594 assert!(
1595 matches!(
1596 last_event,
1597 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1598 ),
1599 "unexpected event {last_event:?}"
1600 );
1601
1602 // Ensure we can still send a new message after cancellation.
1603 let events = thread
1604 .update(cx, |thread, cx| {
1605 thread.send(
1606 UserMessageId::new(),
1607 ["Testing: reply with 'Hello' then stop."],
1608 cx,
1609 )
1610 })
1611 .unwrap()
1612 .collect::<Vec<_>>()
1613 .await;
1614 thread.update(cx, |thread, _cx| {
1615 let message = thread.last_message().unwrap();
1616 let agent_message = message.as_agent_message().unwrap();
1617 assert_eq!(
1618 agent_message.content,
1619 vec![AgentMessageContent::Text("Hello".to_string())]
1620 );
1621 });
1622 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1623}
1624
1625#[gpui::test]
1626async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1627 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1628 always_allow_tools(cx);
1629 let fake_model = model.as_fake();
1630
1631 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1632 let environment = Rc::new(FakeThreadEnvironment {
1633 handle: handle.clone(),
1634 });
1635
1636 let mut events = thread
1637 .update(cx, |thread, cx| {
1638 thread.add_tool(crate::TerminalTool::new(
1639 thread.project().clone(),
1640 environment,
1641 ));
1642 thread.send(UserMessageId::new(), ["run a command"], cx)
1643 })
1644 .unwrap();
1645
1646 cx.run_until_parked();
1647
1648 // Simulate the model calling the terminal tool
1649 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1650 LanguageModelToolUse {
1651 id: "terminal_tool_1".into(),
1652 name: TerminalTool::NAME.into(),
1653 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1654 input: json!({"command": "sleep 1000", "cd": "."}),
1655 is_input_complete: true,
1656 thought_signature: None,
1657 },
1658 ));
1659 fake_model.end_last_completion_stream();
1660
1661 // Wait for the terminal tool to start running
1662 wait_for_terminal_tool_started(&mut events, cx).await;
1663
1664 // Cancel the thread while the terminal is running
1665 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1666
1667 // Collect remaining events, driving the executor to let cancellation complete
1668 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1669
1670 // Verify the terminal was killed
1671 assert!(
1672 handle.was_killed(),
1673 "expected terminal handle to be killed on cancellation"
1674 );
1675
1676 // Verify we got a cancellation stop event
1677 assert_eq!(
1678 stop_events(remaining_events),
1679 vec![acp::StopReason::Cancelled],
1680 );
1681
1682 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1683 thread.update(cx, |thread, _cx| {
1684 let message = thread.last_message().unwrap();
1685 let agent_message = message.as_agent_message().unwrap();
1686
1687 let tool_use = agent_message
1688 .content
1689 .iter()
1690 .find_map(|content| match content {
1691 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1692 _ => None,
1693 })
1694 .expect("expected tool use in agent message");
1695
1696 let tool_result = agent_message
1697 .tool_results
1698 .get(&tool_use.id)
1699 .expect("expected tool result");
1700
1701 let result_text = match &tool_result.content {
1702 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1703 _ => panic!("expected text content in tool result"),
1704 };
1705
1706 // "partial output" comes from FakeTerminalHandle's output field
1707 assert!(
1708 result_text.contains("partial output"),
1709 "expected tool result to contain terminal output, got: {result_text}"
1710 );
1711 // Match the actual format from process_content in terminal_tool.rs
1712 assert!(
1713 result_text.contains("The user stopped this command"),
1714 "expected tool result to indicate user stopped, got: {result_text}"
1715 );
1716 });
1717
1718 // Verify we can send a new message after cancellation
1719 verify_thread_recovery(&thread, &fake_model, cx).await;
1720}
1721
1722#[gpui::test]
1723async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1724 // This test verifies that tools which properly handle cancellation via
1725 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1726 // to cancellation and report that they were cancelled.
1727 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1728 always_allow_tools(cx);
1729 let fake_model = model.as_fake();
1730
1731 let (tool, was_cancelled) = CancellationAwareTool::new();
1732
1733 let mut events = thread
1734 .update(cx, |thread, cx| {
1735 thread.add_tool(tool);
1736 thread.send(
1737 UserMessageId::new(),
1738 ["call the cancellation aware tool"],
1739 cx,
1740 )
1741 })
1742 .unwrap();
1743
1744 cx.run_until_parked();
1745
1746 // Simulate the model calling the cancellation-aware tool
1747 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1748 LanguageModelToolUse {
1749 id: "cancellation_aware_1".into(),
1750 name: "cancellation_aware".into(),
1751 raw_input: r#"{}"#.into(),
1752 input: json!({}),
1753 is_input_complete: true,
1754 thought_signature: None,
1755 },
1756 ));
1757 fake_model.end_last_completion_stream();
1758
1759 cx.run_until_parked();
1760
1761 // Wait for the tool call to be reported
1762 let mut tool_started = false;
1763 let deadline = cx.executor().num_cpus() * 100;
1764 for _ in 0..deadline {
1765 cx.run_until_parked();
1766
1767 while let Some(Some(event)) = events.next().now_or_never() {
1768 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1769 if tool_call.title == "Cancellation Aware Tool" {
1770 tool_started = true;
1771 break;
1772 }
1773 }
1774 }
1775
1776 if tool_started {
1777 break;
1778 }
1779
1780 cx.background_executor
1781 .timer(Duration::from_millis(10))
1782 .await;
1783 }
1784 assert!(tool_started, "expected cancellation aware tool to start");
1785
1786 // Cancel the thread and wait for it to complete
1787 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1788
1789 // The cancel task should complete promptly because the tool handles cancellation
1790 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1791 futures::select! {
1792 _ = cancel_task.fuse() => {}
1793 _ = timeout.fuse() => {
1794 panic!("cancel task timed out - tool did not respond to cancellation");
1795 }
1796 }
1797
1798 // Verify the tool detected cancellation via its flag
1799 assert!(
1800 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1801 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1802 );
1803
1804 // Collect remaining events
1805 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1806
1807 // Verify we got a cancellation stop event
1808 assert_eq!(
1809 stop_events(remaining_events),
1810 vec![acp::StopReason::Cancelled],
1811 );
1812
1813 // Verify we can send a new message after cancellation
1814 verify_thread_recovery(&thread, &fake_model, cx).await;
1815}
1816
1817/// Helper to verify thread can recover after cancellation by sending a simple message.
1818async fn verify_thread_recovery(
1819 thread: &Entity<Thread>,
1820 fake_model: &FakeLanguageModel,
1821 cx: &mut TestAppContext,
1822) {
1823 let events = thread
1824 .update(cx, |thread, cx| {
1825 thread.send(
1826 UserMessageId::new(),
1827 ["Testing: reply with 'Hello' then stop."],
1828 cx,
1829 )
1830 })
1831 .unwrap();
1832 cx.run_until_parked();
1833 fake_model.send_last_completion_stream_text_chunk("Hello");
1834 fake_model
1835 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1836 fake_model.end_last_completion_stream();
1837
1838 let events = events.collect::<Vec<_>>().await;
1839 thread.update(cx, |thread, _cx| {
1840 let message = thread.last_message().unwrap();
1841 let agent_message = message.as_agent_message().unwrap();
1842 assert_eq!(
1843 agent_message.content,
1844 vec![AgentMessageContent::Text("Hello".to_string())]
1845 );
1846 });
1847 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1848}
1849
1850/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1851async fn wait_for_terminal_tool_started(
1852 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1853 cx: &mut TestAppContext,
1854) {
1855 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1856 for _ in 0..deadline {
1857 cx.run_until_parked();
1858
1859 while let Some(Some(event)) = events.next().now_or_never() {
1860 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1861 update,
1862 ))) = &event
1863 {
1864 if update.fields.content.as_ref().is_some_and(|content| {
1865 content
1866 .iter()
1867 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1868 }) {
1869 return;
1870 }
1871 }
1872 }
1873
1874 cx.background_executor
1875 .timer(Duration::from_millis(10))
1876 .await;
1877 }
1878 panic!("terminal tool did not start within the expected time");
1879}
1880
1881/// Collects events until a Stop event is received, driving the executor to completion.
1882async fn collect_events_until_stop(
1883 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1884 cx: &mut TestAppContext,
1885) -> Vec<Result<ThreadEvent>> {
1886 let mut collected = Vec::new();
1887 let deadline = cx.executor().num_cpus() * 200;
1888
1889 for _ in 0..deadline {
1890 cx.executor().advance_clock(Duration::from_millis(10));
1891 cx.run_until_parked();
1892
1893 while let Some(Some(event)) = events.next().now_or_never() {
1894 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1895 collected.push(event);
1896 if is_stop {
1897 return collected;
1898 }
1899 }
1900 }
1901 panic!(
1902 "did not receive Stop event within the expected time; collected {} events",
1903 collected.len()
1904 );
1905}
1906
1907#[gpui::test]
1908async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1909 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1910 always_allow_tools(cx);
1911 let fake_model = model.as_fake();
1912
1913 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1914 let environment = Rc::new(FakeThreadEnvironment {
1915 handle: handle.clone(),
1916 });
1917
1918 let message_id = UserMessageId::new();
1919 let mut events = thread
1920 .update(cx, |thread, cx| {
1921 thread.add_tool(crate::TerminalTool::new(
1922 thread.project().clone(),
1923 environment,
1924 ));
1925 thread.send(message_id.clone(), ["run a command"], cx)
1926 })
1927 .unwrap();
1928
1929 cx.run_until_parked();
1930
1931 // Simulate the model calling the terminal tool
1932 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1933 LanguageModelToolUse {
1934 id: "terminal_tool_1".into(),
1935 name: TerminalTool::NAME.into(),
1936 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1937 input: json!({"command": "sleep 1000", "cd": "."}),
1938 is_input_complete: true,
1939 thought_signature: None,
1940 },
1941 ));
1942 fake_model.end_last_completion_stream();
1943
1944 // Wait for the terminal tool to start running
1945 wait_for_terminal_tool_started(&mut events, cx).await;
1946
1947 // Truncate the thread while the terminal is running
1948 thread
1949 .update(cx, |thread, cx| thread.truncate(message_id, cx))
1950 .unwrap();
1951
1952 // Drive the executor to let cancellation complete
1953 let _ = collect_events_until_stop(&mut events, cx).await;
1954
1955 // Verify the terminal was killed
1956 assert!(
1957 handle.was_killed(),
1958 "expected terminal handle to be killed on truncate"
1959 );
1960
1961 // Verify the thread is empty after truncation
1962 thread.update(cx, |thread, _cx| {
1963 assert_eq!(
1964 thread.to_markdown(),
1965 "",
1966 "expected thread to be empty after truncating the only message"
1967 );
1968 });
1969
1970 // Verify we can send a new message after truncation
1971 verify_thread_recovery(&thread, &fake_model, cx).await;
1972}
1973
1974#[gpui::test]
1975async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1976 // Tests that cancellation properly kills all running terminal tools when multiple are active.
1977 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1978 always_allow_tools(cx);
1979 let fake_model = model.as_fake();
1980
1981 let environment = Rc::new(MultiTerminalEnvironment::new());
1982
1983 let mut events = thread
1984 .update(cx, |thread, cx| {
1985 thread.add_tool(crate::TerminalTool::new(
1986 thread.project().clone(),
1987 environment.clone(),
1988 ));
1989 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1990 })
1991 .unwrap();
1992
1993 cx.run_until_parked();
1994
1995 // Simulate the model calling two terminal tools
1996 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1997 LanguageModelToolUse {
1998 id: "terminal_tool_1".into(),
1999 name: TerminalTool::NAME.into(),
2000 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2001 input: json!({"command": "sleep 1000", "cd": "."}),
2002 is_input_complete: true,
2003 thought_signature: None,
2004 },
2005 ));
2006 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2007 LanguageModelToolUse {
2008 id: "terminal_tool_2".into(),
2009 name: TerminalTool::NAME.into(),
2010 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2011 input: json!({"command": "sleep 2000", "cd": "."}),
2012 is_input_complete: true,
2013 thought_signature: None,
2014 },
2015 ));
2016 fake_model.end_last_completion_stream();
2017
2018 // Wait for both terminal tools to start by counting terminal content updates
2019 let mut terminals_started = 0;
2020 let deadline = cx.executor().num_cpus() * 100;
2021 for _ in 0..deadline {
2022 cx.run_until_parked();
2023
2024 while let Some(Some(event)) = events.next().now_or_never() {
2025 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2026 update,
2027 ))) = &event
2028 {
2029 if update.fields.content.as_ref().is_some_and(|content| {
2030 content
2031 .iter()
2032 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2033 }) {
2034 terminals_started += 1;
2035 if terminals_started >= 2 {
2036 break;
2037 }
2038 }
2039 }
2040 }
2041 if terminals_started >= 2 {
2042 break;
2043 }
2044
2045 cx.background_executor
2046 .timer(Duration::from_millis(10))
2047 .await;
2048 }
2049 assert!(
2050 terminals_started >= 2,
2051 "expected 2 terminal tools to start, got {terminals_started}"
2052 );
2053
2054 // Cancel the thread while both terminals are running
2055 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2056
2057 // Collect remaining events
2058 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2059
2060 // Verify both terminal handles were killed
2061 let handles = environment.handles();
2062 assert_eq!(
2063 handles.len(),
2064 2,
2065 "expected 2 terminal handles to be created"
2066 );
2067 assert!(
2068 handles[0].was_killed(),
2069 "expected first terminal handle to be killed on cancellation"
2070 );
2071 assert!(
2072 handles[1].was_killed(),
2073 "expected second terminal handle to be killed on cancellation"
2074 );
2075
2076 // Verify we got a cancellation stop event
2077 assert_eq!(
2078 stop_events(remaining_events),
2079 vec![acp::StopReason::Cancelled],
2080 );
2081}
2082
2083#[gpui::test]
2084async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2085 // Tests that clicking the stop button on the terminal card (as opposed to the main
2086 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2087 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2088 always_allow_tools(cx);
2089 let fake_model = model.as_fake();
2090
2091 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2092 let environment = Rc::new(FakeThreadEnvironment {
2093 handle: handle.clone(),
2094 });
2095
2096 let mut events = thread
2097 .update(cx, |thread, cx| {
2098 thread.add_tool(crate::TerminalTool::new(
2099 thread.project().clone(),
2100 environment,
2101 ));
2102 thread.send(UserMessageId::new(), ["run a command"], cx)
2103 })
2104 .unwrap();
2105
2106 cx.run_until_parked();
2107
2108 // Simulate the model calling the terminal tool
2109 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2110 LanguageModelToolUse {
2111 id: "terminal_tool_1".into(),
2112 name: TerminalTool::NAME.into(),
2113 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2114 input: json!({"command": "sleep 1000", "cd": "."}),
2115 is_input_complete: true,
2116 thought_signature: None,
2117 },
2118 ));
2119 fake_model.end_last_completion_stream();
2120
2121 // Wait for the terminal tool to start running
2122 wait_for_terminal_tool_started(&mut events, cx).await;
2123
2124 // Simulate user clicking stop on the terminal card itself.
2125 // This sets the flag and signals exit (simulating what the real UI would do).
2126 handle.set_stopped_by_user(true);
2127 handle.killed.store(true, Ordering::SeqCst);
2128 handle.signal_exit();
2129
2130 // Wait for the tool to complete
2131 cx.run_until_parked();
2132
2133 // The thread continues after tool completion - simulate the model ending its turn
2134 fake_model
2135 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2136 fake_model.end_last_completion_stream();
2137
2138 // Collect remaining events
2139 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2140
2141 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2142 assert_eq!(
2143 stop_events(remaining_events),
2144 vec![acp::StopReason::EndTurn],
2145 );
2146
2147 // Verify the tool result indicates user stopped
2148 thread.update(cx, |thread, _cx| {
2149 let message = thread.last_message().unwrap();
2150 let agent_message = message.as_agent_message().unwrap();
2151
2152 let tool_use = agent_message
2153 .content
2154 .iter()
2155 .find_map(|content| match content {
2156 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2157 _ => None,
2158 })
2159 .expect("expected tool use in agent message");
2160
2161 let tool_result = agent_message
2162 .tool_results
2163 .get(&tool_use.id)
2164 .expect("expected tool result");
2165
2166 let result_text = match &tool_result.content {
2167 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2168 _ => panic!("expected text content in tool result"),
2169 };
2170
2171 assert!(
2172 result_text.contains("The user stopped this command"),
2173 "expected tool result to indicate user stopped, got: {result_text}"
2174 );
2175 });
2176}
2177
2178#[gpui::test]
2179async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2180 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2181 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2182 always_allow_tools(cx);
2183 let fake_model = model.as_fake();
2184
2185 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2186 let environment = Rc::new(FakeThreadEnvironment {
2187 handle: handle.clone(),
2188 });
2189
2190 let mut events = thread
2191 .update(cx, |thread, cx| {
2192 thread.add_tool(crate::TerminalTool::new(
2193 thread.project().clone(),
2194 environment,
2195 ));
2196 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2197 })
2198 .unwrap();
2199
2200 cx.run_until_parked();
2201
2202 // Simulate the model calling the terminal tool with a short timeout
2203 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2204 LanguageModelToolUse {
2205 id: "terminal_tool_1".into(),
2206 name: TerminalTool::NAME.into(),
2207 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2208 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2209 is_input_complete: true,
2210 thought_signature: None,
2211 },
2212 ));
2213 fake_model.end_last_completion_stream();
2214
2215 // Wait for the terminal tool to start running
2216 wait_for_terminal_tool_started(&mut events, cx).await;
2217
2218 // Advance clock past the timeout
2219 cx.executor().advance_clock(Duration::from_millis(200));
2220 cx.run_until_parked();
2221
2222 // The thread continues after tool completion - simulate the model ending its turn
2223 fake_model
2224 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2225 fake_model.end_last_completion_stream();
2226
2227 // Collect remaining events
2228 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2229
2230 // Verify the terminal was killed due to timeout
2231 assert!(
2232 handle.was_killed(),
2233 "expected terminal handle to be killed on timeout"
2234 );
2235
2236 // Verify we got an EndTurn (the tool completed, just with timeout)
2237 assert_eq!(
2238 stop_events(remaining_events),
2239 vec![acp::StopReason::EndTurn],
2240 );
2241
2242 // Verify the tool result indicates timeout, not user stopped
2243 thread.update(cx, |thread, _cx| {
2244 let message = thread.last_message().unwrap();
2245 let agent_message = message.as_agent_message().unwrap();
2246
2247 let tool_use = agent_message
2248 .content
2249 .iter()
2250 .find_map(|content| match content {
2251 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2252 _ => None,
2253 })
2254 .expect("expected tool use in agent message");
2255
2256 let tool_result = agent_message
2257 .tool_results
2258 .get(&tool_use.id)
2259 .expect("expected tool result");
2260
2261 let result_text = match &tool_result.content {
2262 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2263 _ => panic!("expected text content in tool result"),
2264 };
2265
2266 assert!(
2267 result_text.contains("timed out"),
2268 "expected tool result to indicate timeout, got: {result_text}"
2269 );
2270 assert!(
2271 !result_text.contains("The user stopped"),
2272 "tool result should not mention user stopped when it timed out, got: {result_text}"
2273 );
2274 });
2275}
2276
2277#[gpui::test]
2278async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2279 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2280 let fake_model = model.as_fake();
2281
2282 let events_1 = thread
2283 .update(cx, |thread, cx| {
2284 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2285 })
2286 .unwrap();
2287 cx.run_until_parked();
2288 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2289 cx.run_until_parked();
2290
2291 let events_2 = thread
2292 .update(cx, |thread, cx| {
2293 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2294 })
2295 .unwrap();
2296 cx.run_until_parked();
2297 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2298 fake_model
2299 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2300 fake_model.end_last_completion_stream();
2301
2302 let events_1 = events_1.collect::<Vec<_>>().await;
2303 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2304 let events_2 = events_2.collect::<Vec<_>>().await;
2305 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2306}
2307
2308#[gpui::test]
2309async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2310 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2311 let fake_model = model.as_fake();
2312
2313 let events_1 = thread
2314 .update(cx, |thread, cx| {
2315 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2316 })
2317 .unwrap();
2318 cx.run_until_parked();
2319 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2320 fake_model
2321 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2322 fake_model.end_last_completion_stream();
2323 let events_1 = events_1.collect::<Vec<_>>().await;
2324
2325 let events_2 = thread
2326 .update(cx, |thread, cx| {
2327 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2328 })
2329 .unwrap();
2330 cx.run_until_parked();
2331 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2332 fake_model
2333 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2334 fake_model.end_last_completion_stream();
2335 let events_2 = events_2.collect::<Vec<_>>().await;
2336
2337 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2338 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2339}
2340
2341#[gpui::test]
2342async fn test_refusal(cx: &mut TestAppContext) {
2343 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2344 let fake_model = model.as_fake();
2345
2346 let events = thread
2347 .update(cx, |thread, cx| {
2348 thread.send(UserMessageId::new(), ["Hello"], cx)
2349 })
2350 .unwrap();
2351 cx.run_until_parked();
2352 thread.read_with(cx, |thread, _| {
2353 assert_eq!(
2354 thread.to_markdown(),
2355 indoc! {"
2356 ## User
2357
2358 Hello
2359 "}
2360 );
2361 });
2362
2363 fake_model.send_last_completion_stream_text_chunk("Hey!");
2364 cx.run_until_parked();
2365 thread.read_with(cx, |thread, _| {
2366 assert_eq!(
2367 thread.to_markdown(),
2368 indoc! {"
2369 ## User
2370
2371 Hello
2372
2373 ## Assistant
2374
2375 Hey!
2376 "}
2377 );
2378 });
2379
2380 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2381 fake_model
2382 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2383 let events = events.collect::<Vec<_>>().await;
2384 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2385 thread.read_with(cx, |thread, _| {
2386 assert_eq!(thread.to_markdown(), "");
2387 });
2388}
2389
2390#[gpui::test]
2391async fn test_truncate_first_message(cx: &mut TestAppContext) {
2392 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2393 let fake_model = model.as_fake();
2394
2395 let message_id = UserMessageId::new();
2396 thread
2397 .update(cx, |thread, cx| {
2398 thread.send(message_id.clone(), ["Hello"], cx)
2399 })
2400 .unwrap();
2401 cx.run_until_parked();
2402 thread.read_with(cx, |thread, _| {
2403 assert_eq!(
2404 thread.to_markdown(),
2405 indoc! {"
2406 ## User
2407
2408 Hello
2409 "}
2410 );
2411 assert_eq!(thread.latest_token_usage(), None);
2412 });
2413
2414 fake_model.send_last_completion_stream_text_chunk("Hey!");
2415 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2416 language_model::TokenUsage {
2417 input_tokens: 32_000,
2418 output_tokens: 16_000,
2419 cache_creation_input_tokens: 0,
2420 cache_read_input_tokens: 0,
2421 },
2422 ));
2423 cx.run_until_parked();
2424 thread.read_with(cx, |thread, _| {
2425 assert_eq!(
2426 thread.to_markdown(),
2427 indoc! {"
2428 ## User
2429
2430 Hello
2431
2432 ## Assistant
2433
2434 Hey!
2435 "}
2436 );
2437 assert_eq!(
2438 thread.latest_token_usage(),
2439 Some(acp_thread::TokenUsage {
2440 used_tokens: 32_000 + 16_000,
2441 max_tokens: 1_000_000,
2442 input_tokens: 32_000,
2443 output_tokens: 16_000,
2444 })
2445 );
2446 });
2447
2448 thread
2449 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2450 .unwrap();
2451 cx.run_until_parked();
2452 thread.read_with(cx, |thread, _| {
2453 assert_eq!(thread.to_markdown(), "");
2454 assert_eq!(thread.latest_token_usage(), None);
2455 });
2456
2457 // Ensure we can still send a new message after truncation.
2458 thread
2459 .update(cx, |thread, cx| {
2460 thread.send(UserMessageId::new(), ["Hi"], cx)
2461 })
2462 .unwrap();
2463 thread.update(cx, |thread, _cx| {
2464 assert_eq!(
2465 thread.to_markdown(),
2466 indoc! {"
2467 ## User
2468
2469 Hi
2470 "}
2471 );
2472 });
2473 cx.run_until_parked();
2474 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2475 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2476 language_model::TokenUsage {
2477 input_tokens: 40_000,
2478 output_tokens: 20_000,
2479 cache_creation_input_tokens: 0,
2480 cache_read_input_tokens: 0,
2481 },
2482 ));
2483 cx.run_until_parked();
2484 thread.read_with(cx, |thread, _| {
2485 assert_eq!(
2486 thread.to_markdown(),
2487 indoc! {"
2488 ## User
2489
2490 Hi
2491
2492 ## Assistant
2493
2494 Ahoy!
2495 "}
2496 );
2497
2498 assert_eq!(
2499 thread.latest_token_usage(),
2500 Some(acp_thread::TokenUsage {
2501 used_tokens: 40_000 + 20_000,
2502 max_tokens: 1_000_000,
2503 input_tokens: 40_000,
2504 output_tokens: 20_000,
2505 })
2506 );
2507 });
2508}
2509
2510#[gpui::test]
2511async fn test_truncate_second_message(cx: &mut TestAppContext) {
2512 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2513 let fake_model = model.as_fake();
2514
2515 thread
2516 .update(cx, |thread, cx| {
2517 thread.send(UserMessageId::new(), ["Message 1"], cx)
2518 })
2519 .unwrap();
2520 cx.run_until_parked();
2521 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2522 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2523 language_model::TokenUsage {
2524 input_tokens: 32_000,
2525 output_tokens: 16_000,
2526 cache_creation_input_tokens: 0,
2527 cache_read_input_tokens: 0,
2528 },
2529 ));
2530 fake_model.end_last_completion_stream();
2531 cx.run_until_parked();
2532
2533 let assert_first_message_state = |cx: &mut TestAppContext| {
2534 thread.clone().read_with(cx, |thread, _| {
2535 assert_eq!(
2536 thread.to_markdown(),
2537 indoc! {"
2538 ## User
2539
2540 Message 1
2541
2542 ## Assistant
2543
2544 Message 1 response
2545 "}
2546 );
2547
2548 assert_eq!(
2549 thread.latest_token_usage(),
2550 Some(acp_thread::TokenUsage {
2551 used_tokens: 32_000 + 16_000,
2552 max_tokens: 1_000_000,
2553 input_tokens: 32_000,
2554 output_tokens: 16_000,
2555 })
2556 );
2557 });
2558 };
2559
2560 assert_first_message_state(cx);
2561
2562 let second_message_id = UserMessageId::new();
2563 thread
2564 .update(cx, |thread, cx| {
2565 thread.send(second_message_id.clone(), ["Message 2"], cx)
2566 })
2567 .unwrap();
2568 cx.run_until_parked();
2569
2570 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2571 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2572 language_model::TokenUsage {
2573 input_tokens: 40_000,
2574 output_tokens: 20_000,
2575 cache_creation_input_tokens: 0,
2576 cache_read_input_tokens: 0,
2577 },
2578 ));
2579 fake_model.end_last_completion_stream();
2580 cx.run_until_parked();
2581
2582 thread.read_with(cx, |thread, _| {
2583 assert_eq!(
2584 thread.to_markdown(),
2585 indoc! {"
2586 ## User
2587
2588 Message 1
2589
2590 ## Assistant
2591
2592 Message 1 response
2593
2594 ## User
2595
2596 Message 2
2597
2598 ## Assistant
2599
2600 Message 2 response
2601 "}
2602 );
2603
2604 assert_eq!(
2605 thread.latest_token_usage(),
2606 Some(acp_thread::TokenUsage {
2607 used_tokens: 40_000 + 20_000,
2608 max_tokens: 1_000_000,
2609 input_tokens: 40_000,
2610 output_tokens: 20_000,
2611 })
2612 );
2613 });
2614
2615 thread
2616 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2617 .unwrap();
2618 cx.run_until_parked();
2619
2620 assert_first_message_state(cx);
2621}
2622
2623#[gpui::test]
2624async fn test_title_generation(cx: &mut TestAppContext) {
2625 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2626 let fake_model = model.as_fake();
2627
2628 let summary_model = Arc::new(FakeLanguageModel::default());
2629 thread.update(cx, |thread, cx| {
2630 thread.set_summarization_model(Some(summary_model.clone()), cx)
2631 });
2632
2633 let send = thread
2634 .update(cx, |thread, cx| {
2635 thread.send(UserMessageId::new(), ["Hello"], cx)
2636 })
2637 .unwrap();
2638 cx.run_until_parked();
2639
2640 fake_model.send_last_completion_stream_text_chunk("Hey!");
2641 fake_model.end_last_completion_stream();
2642 cx.run_until_parked();
2643 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2644
2645 // Ensure the summary model has been invoked to generate a title.
2646 summary_model.send_last_completion_stream_text_chunk("Hello ");
2647 summary_model.send_last_completion_stream_text_chunk("world\nG");
2648 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2649 summary_model.end_last_completion_stream();
2650 send.collect::<Vec<_>>().await;
2651 cx.run_until_parked();
2652 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2653
2654 // Send another message, ensuring no title is generated this time.
2655 let send = thread
2656 .update(cx, |thread, cx| {
2657 thread.send(UserMessageId::new(), ["Hello again"], cx)
2658 })
2659 .unwrap();
2660 cx.run_until_parked();
2661 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2662 fake_model.end_last_completion_stream();
2663 cx.run_until_parked();
2664 assert_eq!(summary_model.pending_completions(), Vec::new());
2665 send.collect::<Vec<_>>().await;
2666 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2667}
2668
2669#[gpui::test]
2670async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2671 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2672 let fake_model = model.as_fake();
2673
2674 let _events = thread
2675 .update(cx, |thread, cx| {
2676 thread.add_tool(ToolRequiringPermission);
2677 thread.add_tool(EchoTool);
2678 thread.send(UserMessageId::new(), ["Hey!"], cx)
2679 })
2680 .unwrap();
2681 cx.run_until_parked();
2682
2683 let permission_tool_use = LanguageModelToolUse {
2684 id: "tool_id_1".into(),
2685 name: ToolRequiringPermission::NAME.into(),
2686 raw_input: "{}".into(),
2687 input: json!({}),
2688 is_input_complete: true,
2689 thought_signature: None,
2690 };
2691 let echo_tool_use = LanguageModelToolUse {
2692 id: "tool_id_2".into(),
2693 name: EchoTool::NAME.into(),
2694 raw_input: json!({"text": "test"}).to_string(),
2695 input: json!({"text": "test"}),
2696 is_input_complete: true,
2697 thought_signature: None,
2698 };
2699 fake_model.send_last_completion_stream_text_chunk("Hi!");
2700 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2701 permission_tool_use,
2702 ));
2703 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2704 echo_tool_use.clone(),
2705 ));
2706 fake_model.end_last_completion_stream();
2707 cx.run_until_parked();
2708
2709 // Ensure pending tools are skipped when building a request.
2710 let request = thread
2711 .read_with(cx, |thread, cx| {
2712 thread.build_completion_request(CompletionIntent::EditFile, cx)
2713 })
2714 .unwrap();
2715 assert_eq!(
2716 request.messages[1..],
2717 vec![
2718 LanguageModelRequestMessage {
2719 role: Role::User,
2720 content: vec!["Hey!".into()],
2721 cache: true,
2722 reasoning_details: None,
2723 },
2724 LanguageModelRequestMessage {
2725 role: Role::Assistant,
2726 content: vec![
2727 MessageContent::Text("Hi!".into()),
2728 MessageContent::ToolUse(echo_tool_use.clone())
2729 ],
2730 cache: false,
2731 reasoning_details: None,
2732 },
2733 LanguageModelRequestMessage {
2734 role: Role::User,
2735 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2736 tool_use_id: echo_tool_use.id.clone(),
2737 tool_name: echo_tool_use.name,
2738 is_error: false,
2739 content: "test".into(),
2740 output: Some("test".into())
2741 })],
2742 cache: false,
2743 reasoning_details: None,
2744 },
2745 ],
2746 );
2747}
2748
2749#[gpui::test]
2750async fn test_agent_connection(cx: &mut TestAppContext) {
2751 cx.update(settings::init);
2752 let templates = Templates::new();
2753
2754 // Initialize language model system with test provider
2755 cx.update(|cx| {
2756 gpui_tokio::init(cx);
2757
2758 let http_client = FakeHttpClient::with_404_response();
2759 let clock = Arc::new(clock::FakeSystemClock::new());
2760 let client = Client::new(clock, http_client, cx);
2761 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2762 language_model::init(client.clone(), cx);
2763 language_models::init(user_store, client.clone(), cx);
2764 LanguageModelRegistry::test(cx);
2765 });
2766 cx.executor().forbid_parking();
2767
2768 // Create a project for new_thread
2769 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2770 fake_fs.insert_tree(path!("/test"), json!({})).await;
2771 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2772 let cwd = Path::new("/test");
2773 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2774
2775 // Create agent and connection
2776 let agent = NativeAgent::new(
2777 project.clone(),
2778 thread_store,
2779 templates.clone(),
2780 None,
2781 fake_fs.clone(),
2782 &mut cx.to_async(),
2783 )
2784 .await
2785 .unwrap();
2786 let connection = NativeAgentConnection(agent.clone());
2787
2788 // Create a thread using new_thread
2789 let connection_rc = Rc::new(connection.clone());
2790 let acp_thread = cx
2791 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2792 .await
2793 .expect("new_thread should succeed");
2794
2795 // Get the session_id from the AcpThread
2796 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2797
2798 // Test model_selector returns Some
2799 let selector_opt = connection.model_selector(&session_id);
2800 assert!(
2801 selector_opt.is_some(),
2802 "agent should always support ModelSelector"
2803 );
2804 let selector = selector_opt.unwrap();
2805
2806 // Test list_models
2807 let listed_models = cx
2808 .update(|cx| selector.list_models(cx))
2809 .await
2810 .expect("list_models should succeed");
2811 let AgentModelList::Grouped(listed_models) = listed_models else {
2812 panic!("Unexpected model list type");
2813 };
2814 assert!(!listed_models.is_empty(), "should have at least one model");
2815 assert_eq!(
2816 listed_models[&AgentModelGroupName("Fake".into())][0]
2817 .id
2818 .0
2819 .as_ref(),
2820 "fake/fake"
2821 );
2822
2823 // Test selected_model returns the default
2824 let model = cx
2825 .update(|cx| selector.selected_model(cx))
2826 .await
2827 .expect("selected_model should succeed");
2828 let model = cx
2829 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2830 .unwrap();
2831 let model = model.as_fake();
2832 assert_eq!(model.id().0, "fake", "should return default model");
2833
2834 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2835 cx.run_until_parked();
2836 model.send_last_completion_stream_text_chunk("def");
2837 cx.run_until_parked();
2838 acp_thread.read_with(cx, |thread, cx| {
2839 assert_eq!(
2840 thread.to_markdown(cx),
2841 indoc! {"
2842 ## User
2843
2844 abc
2845
2846 ## Assistant
2847
2848 def
2849
2850 "}
2851 )
2852 });
2853
2854 // Test cancel
2855 cx.update(|cx| connection.cancel(&session_id, cx));
2856 request.await.expect("prompt should fail gracefully");
2857
2858 // Ensure that dropping the ACP thread causes the native thread to be
2859 // dropped as well.
2860 cx.update(|_| drop(acp_thread));
2861 let result = cx
2862 .update(|cx| {
2863 connection.prompt(
2864 Some(acp_thread::UserMessageId::new()),
2865 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2866 cx,
2867 )
2868 })
2869 .await;
2870 assert_eq!(
2871 result.as_ref().unwrap_err().to_string(),
2872 "Session not found",
2873 "unexpected result: {:?}",
2874 result
2875 );
2876}
2877
2878#[gpui::test]
2879async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2880 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2881 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2882 let fake_model = model.as_fake();
2883
2884 let mut events = thread
2885 .update(cx, |thread, cx| {
2886 thread.send(UserMessageId::new(), ["Think"], cx)
2887 })
2888 .unwrap();
2889 cx.run_until_parked();
2890
2891 // Simulate streaming partial input.
2892 let input = json!({});
2893 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2894 LanguageModelToolUse {
2895 id: "1".into(),
2896 name: ThinkingTool::NAME.into(),
2897 raw_input: input.to_string(),
2898 input,
2899 is_input_complete: false,
2900 thought_signature: None,
2901 },
2902 ));
2903
2904 // Input streaming completed
2905 let input = json!({ "content": "Thinking hard!" });
2906 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2907 LanguageModelToolUse {
2908 id: "1".into(),
2909 name: "thinking".into(),
2910 raw_input: input.to_string(),
2911 input,
2912 is_input_complete: true,
2913 thought_signature: None,
2914 },
2915 ));
2916 fake_model.end_last_completion_stream();
2917 cx.run_until_parked();
2918
2919 let tool_call = expect_tool_call(&mut events).await;
2920 assert_eq!(
2921 tool_call,
2922 acp::ToolCall::new("1", "Thinking")
2923 .kind(acp::ToolKind::Think)
2924 .raw_input(json!({}))
2925 .meta(acp::Meta::from_iter([(
2926 "tool_name".into(),
2927 "thinking".into()
2928 )]))
2929 );
2930 let update = expect_tool_call_update_fields(&mut events).await;
2931 assert_eq!(
2932 update,
2933 acp::ToolCallUpdate::new(
2934 "1",
2935 acp::ToolCallUpdateFields::new()
2936 .title("Thinking")
2937 .kind(acp::ToolKind::Think)
2938 .raw_input(json!({ "content": "Thinking hard!"}))
2939 )
2940 );
2941 let update = expect_tool_call_update_fields(&mut events).await;
2942 assert_eq!(
2943 update,
2944 acp::ToolCallUpdate::new(
2945 "1",
2946 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2947 )
2948 );
2949 let update = expect_tool_call_update_fields(&mut events).await;
2950 assert_eq!(
2951 update,
2952 acp::ToolCallUpdate::new(
2953 "1",
2954 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2955 )
2956 );
2957 let update = expect_tool_call_update_fields(&mut events).await;
2958 assert_eq!(
2959 update,
2960 acp::ToolCallUpdate::new(
2961 "1",
2962 acp::ToolCallUpdateFields::new()
2963 .status(acp::ToolCallStatus::Completed)
2964 .raw_output("Finished thinking.")
2965 )
2966 );
2967}
2968
2969#[gpui::test]
2970async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2971 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2972 let fake_model = model.as_fake();
2973
2974 let mut events = thread
2975 .update(cx, |thread, cx| {
2976 thread.send(UserMessageId::new(), ["Hello!"], cx)
2977 })
2978 .unwrap();
2979 cx.run_until_parked();
2980
2981 fake_model.send_last_completion_stream_text_chunk("Hey!");
2982 fake_model.end_last_completion_stream();
2983
2984 let mut retry_events = Vec::new();
2985 while let Some(Ok(event)) = events.next().await {
2986 match event {
2987 ThreadEvent::Retry(retry_status) => {
2988 retry_events.push(retry_status);
2989 }
2990 ThreadEvent::Stop(..) => break,
2991 _ => {}
2992 }
2993 }
2994
2995 assert_eq!(retry_events.len(), 0);
2996 thread.read_with(cx, |thread, _cx| {
2997 assert_eq!(
2998 thread.to_markdown(),
2999 indoc! {"
3000 ## User
3001
3002 Hello!
3003
3004 ## Assistant
3005
3006 Hey!
3007 "}
3008 )
3009 });
3010}
3011
3012#[gpui::test]
3013async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3014 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3015 let fake_model = model.as_fake();
3016
3017 let mut events = thread
3018 .update(cx, |thread, cx| {
3019 thread.send(UserMessageId::new(), ["Hello!"], cx)
3020 })
3021 .unwrap();
3022 cx.run_until_parked();
3023
3024 fake_model.send_last_completion_stream_text_chunk("Hey,");
3025 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3026 provider: LanguageModelProviderName::new("Anthropic"),
3027 retry_after: Some(Duration::from_secs(3)),
3028 });
3029 fake_model.end_last_completion_stream();
3030
3031 cx.executor().advance_clock(Duration::from_secs(3));
3032 cx.run_until_parked();
3033
3034 fake_model.send_last_completion_stream_text_chunk("there!");
3035 fake_model.end_last_completion_stream();
3036 cx.run_until_parked();
3037
3038 let mut retry_events = Vec::new();
3039 while let Some(Ok(event)) = events.next().await {
3040 match event {
3041 ThreadEvent::Retry(retry_status) => {
3042 retry_events.push(retry_status);
3043 }
3044 ThreadEvent::Stop(..) => break,
3045 _ => {}
3046 }
3047 }
3048
3049 assert_eq!(retry_events.len(), 1);
3050 assert!(matches!(
3051 retry_events[0],
3052 acp_thread::RetryStatus { attempt: 1, .. }
3053 ));
3054 thread.read_with(cx, |thread, _cx| {
3055 assert_eq!(
3056 thread.to_markdown(),
3057 indoc! {"
3058 ## User
3059
3060 Hello!
3061
3062 ## Assistant
3063
3064 Hey,
3065
3066 [resume]
3067
3068 ## Assistant
3069
3070 there!
3071 "}
3072 )
3073 });
3074}
3075
3076#[gpui::test]
3077async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3078 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3079 let fake_model = model.as_fake();
3080
3081 let events = thread
3082 .update(cx, |thread, cx| {
3083 thread.add_tool(EchoTool);
3084 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3085 })
3086 .unwrap();
3087 cx.run_until_parked();
3088
3089 let tool_use_1 = LanguageModelToolUse {
3090 id: "tool_1".into(),
3091 name: EchoTool::NAME.into(),
3092 raw_input: json!({"text": "test"}).to_string(),
3093 input: json!({"text": "test"}),
3094 is_input_complete: true,
3095 thought_signature: None,
3096 };
3097 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3098 tool_use_1.clone(),
3099 ));
3100 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3101 provider: LanguageModelProviderName::new("Anthropic"),
3102 retry_after: Some(Duration::from_secs(3)),
3103 });
3104 fake_model.end_last_completion_stream();
3105
3106 cx.executor().advance_clock(Duration::from_secs(3));
3107 let completion = fake_model.pending_completions().pop().unwrap();
3108 assert_eq!(
3109 completion.messages[1..],
3110 vec![
3111 LanguageModelRequestMessage {
3112 role: Role::User,
3113 content: vec!["Call the echo tool!".into()],
3114 cache: false,
3115 reasoning_details: None,
3116 },
3117 LanguageModelRequestMessage {
3118 role: Role::Assistant,
3119 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3120 cache: false,
3121 reasoning_details: None,
3122 },
3123 LanguageModelRequestMessage {
3124 role: Role::User,
3125 content: vec![language_model::MessageContent::ToolResult(
3126 LanguageModelToolResult {
3127 tool_use_id: tool_use_1.id.clone(),
3128 tool_name: tool_use_1.name.clone(),
3129 is_error: false,
3130 content: "test".into(),
3131 output: Some("test".into())
3132 }
3133 )],
3134 cache: true,
3135 reasoning_details: None,
3136 },
3137 ]
3138 );
3139
3140 fake_model.send_last_completion_stream_text_chunk("Done");
3141 fake_model.end_last_completion_stream();
3142 cx.run_until_parked();
3143 events.collect::<Vec<_>>().await;
3144 thread.read_with(cx, |thread, _cx| {
3145 assert_eq!(
3146 thread.last_message(),
3147 Some(Message::Agent(AgentMessage {
3148 content: vec![AgentMessageContent::Text("Done".into())],
3149 tool_results: IndexMap::default(),
3150 reasoning_details: None,
3151 }))
3152 );
3153 })
3154}
3155
3156#[gpui::test]
3157async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3158 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3159 let fake_model = model.as_fake();
3160
3161 let mut events = thread
3162 .update(cx, |thread, cx| {
3163 thread.send(UserMessageId::new(), ["Hello!"], cx)
3164 })
3165 .unwrap();
3166 cx.run_until_parked();
3167
3168 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3169 fake_model.send_last_completion_stream_error(
3170 LanguageModelCompletionError::ServerOverloaded {
3171 provider: LanguageModelProviderName::new("Anthropic"),
3172 retry_after: Some(Duration::from_secs(3)),
3173 },
3174 );
3175 fake_model.end_last_completion_stream();
3176 cx.executor().advance_clock(Duration::from_secs(3));
3177 cx.run_until_parked();
3178 }
3179
3180 let mut errors = Vec::new();
3181 let mut retry_events = Vec::new();
3182 while let Some(event) = events.next().await {
3183 match event {
3184 Ok(ThreadEvent::Retry(retry_status)) => {
3185 retry_events.push(retry_status);
3186 }
3187 Ok(ThreadEvent::Stop(..)) => break,
3188 Err(error) => errors.push(error),
3189 _ => {}
3190 }
3191 }
3192
3193 assert_eq!(
3194 retry_events.len(),
3195 crate::thread::MAX_RETRY_ATTEMPTS as usize
3196 );
3197 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3198 assert_eq!(retry_events[i].attempt, i + 1);
3199 }
3200 assert_eq!(errors.len(), 1);
3201 let error = errors[0]
3202 .downcast_ref::<LanguageModelCompletionError>()
3203 .unwrap();
3204 assert!(matches!(
3205 error,
3206 LanguageModelCompletionError::ServerOverloaded { .. }
3207 ));
3208}
3209
3210/// Filters out the stop events for asserting against in tests
3211fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3212 result_events
3213 .into_iter()
3214 .filter_map(|event| match event.unwrap() {
3215 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3216 _ => None,
3217 })
3218 .collect()
3219}
3220
/// Bundle of fixtures returned by `setup` for driving a `Thread` in tests.
struct ThreadTest {
    /// The language model the thread was constructed with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity handed to the thread.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem holding the settings file and the `/test` tree.
    fs: Arc<FakeFs>,
}
3228
/// Selects which language model `setup` wires into the test thread.
enum TestModel {
    /// A model looked up in the global registry by id (see `TestModel::id`);
    /// `setup` initializes a production client and authenticates its provider.
    Sonnet4,
    /// An in-memory `FakeLanguageModel`.
    Fake,
}
3233
3234impl TestModel {
3235 fn id(&self) -> LanguageModelId {
3236 match self {
3237 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3238 TestModel::Fake => unreachable!(),
3239 }
3240 }
3241}
3242
/// Builds a `Thread` together with the fixtures tests need to drive it.
///
/// Steps, in order:
/// 1. Writes a settings file to a fake filesystem that defines a
///    `test-profile` agent profile enabling the test tools.
/// 2. Initializes settings (and, for `TestModel::Sonnet4`, a production client
///    with a real HTTP client) and starts watching the settings file.
/// 3. Creates a project rooted at `/test`.
/// 4. Resolves the model: a `FakeLanguageModel` for `TestModel::Fake`,
///    otherwise the registry model matching `model.id()` after authenticating
///    its provider.
/// 5. Constructs the thread with that model.
///
/// # Panics
///
/// Panics if the settings directory cannot be created, or (for non-fake
/// models) if the model or its provider is missing from the registry or
/// authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            ThinkingTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider initialization.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Feed the settings file written above into the global settings store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Only hand the model back once authentication has succeeded.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3346
/// Constructor hook (`ctor`) that initializes `env_logger`, but only when the
/// `RUST_LOG` environment variable is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_configured = std::env::var("RUST_LOG").is_ok();
    if !rust_log_configured {
        return;
    }
    env_logger::init();
}
3354
3355fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3356 let fs = fs.clone();
3357 cx.spawn({
3358 async move |cx| {
3359 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3360 cx.background_executor(),
3361 fs,
3362 paths::settings_file().clone(),
3363 );
3364 let _watcher_task = watcher_task;
3365
3366 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3367 cx.update(|cx| {
3368 SettingsStore::update_global(cx, |settings, cx| {
3369 settings.set_user_settings(&new_settings_content, cx)
3370 })
3371 })
3372 .ok();
3373 }
3374 }
3375 })
3376 .detach();
3377}
3378
3379fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3380 completion
3381 .tools
3382 .iter()
3383 .map(|tool| tool.name.clone())
3384 .collect()
3385}
3386
/// Registers and starts a fake context (MCP) server named `name` that
/// advertises `tools`.
///
/// The server is first added to the global `ProjectSettings` as an enabled
/// stdio server, then started on `context_server_store` with a fake transport
/// that answers `Initialize`, `ListTools`, and `CallTool` requests.
///
/// Returns a receiver that yields one `(params, response_sender)` pair per
/// incoming `CallTool` request; the test must send a response through the
/// sender to complete that call.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    // Placeholder binary name; requests are served by the fake
                    // transport created below.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            // Clone per request so the handler can be invoked more than once.
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and wait for it to supply the
                // response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3466
3467#[gpui::test]
3468async fn test_tokens_before_message(cx: &mut TestAppContext) {
3469 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3470 let fake_model = model.as_fake();
3471
3472 // First message
3473 let message_1_id = UserMessageId::new();
3474 thread
3475 .update(cx, |thread, cx| {
3476 thread.send(message_1_id.clone(), ["First message"], cx)
3477 })
3478 .unwrap();
3479 cx.run_until_parked();
3480
3481 // Before any response, tokens_before_message should return None for first message
3482 thread.read_with(cx, |thread, _| {
3483 assert_eq!(
3484 thread.tokens_before_message(&message_1_id),
3485 None,
3486 "First message should have no tokens before it"
3487 );
3488 });
3489
3490 // Complete first message with usage
3491 fake_model.send_last_completion_stream_text_chunk("Response 1");
3492 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3493 language_model::TokenUsage {
3494 input_tokens: 100,
3495 output_tokens: 50,
3496 cache_creation_input_tokens: 0,
3497 cache_read_input_tokens: 0,
3498 },
3499 ));
3500 fake_model.end_last_completion_stream();
3501 cx.run_until_parked();
3502
3503 // First message still has no tokens before it
3504 thread.read_with(cx, |thread, _| {
3505 assert_eq!(
3506 thread.tokens_before_message(&message_1_id),
3507 None,
3508 "First message should still have no tokens before it after response"
3509 );
3510 });
3511
3512 // Second message
3513 let message_2_id = UserMessageId::new();
3514 thread
3515 .update(cx, |thread, cx| {
3516 thread.send(message_2_id.clone(), ["Second message"], cx)
3517 })
3518 .unwrap();
3519 cx.run_until_parked();
3520
3521 // Second message should have first message's input tokens before it
3522 thread.read_with(cx, |thread, _| {
3523 assert_eq!(
3524 thread.tokens_before_message(&message_2_id),
3525 Some(100),
3526 "Second message should have 100 tokens before it (from first request)"
3527 );
3528 });
3529
3530 // Complete second message
3531 fake_model.send_last_completion_stream_text_chunk("Response 2");
3532 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3533 language_model::TokenUsage {
3534 input_tokens: 250, // Total for this request (includes previous context)
3535 output_tokens: 75,
3536 cache_creation_input_tokens: 0,
3537 cache_read_input_tokens: 0,
3538 },
3539 ));
3540 fake_model.end_last_completion_stream();
3541 cx.run_until_parked();
3542
3543 // Third message
3544 let message_3_id = UserMessageId::new();
3545 thread
3546 .update(cx, |thread, cx| {
3547 thread.send(message_3_id.clone(), ["Third message"], cx)
3548 })
3549 .unwrap();
3550 cx.run_until_parked();
3551
3552 // Third message should have second message's input tokens (250) before it
3553 thread.read_with(cx, |thread, _| {
3554 assert_eq!(
3555 thread.tokens_before_message(&message_3_id),
3556 Some(250),
3557 "Third message should have 250 tokens before it (from second request)"
3558 );
3559 // Second message should still have 100
3560 assert_eq!(
3561 thread.tokens_before_message(&message_2_id),
3562 Some(100),
3563 "Second message should still have 100 tokens before it"
3564 );
3565 // First message still has none
3566 assert_eq!(
3567 thread.tokens_before_message(&message_1_id),
3568 None,
3569 "First message should still have no tokens before it"
3570 );
3571 });
3572}
3573
3574#[gpui::test]
3575async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3576 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3577 let fake_model = model.as_fake();
3578
3579 // Set up three messages with responses
3580 let message_1_id = UserMessageId::new();
3581 thread
3582 .update(cx, |thread, cx| {
3583 thread.send(message_1_id.clone(), ["Message 1"], cx)
3584 })
3585 .unwrap();
3586 cx.run_until_parked();
3587 fake_model.send_last_completion_stream_text_chunk("Response 1");
3588 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3589 language_model::TokenUsage {
3590 input_tokens: 100,
3591 output_tokens: 50,
3592 cache_creation_input_tokens: 0,
3593 cache_read_input_tokens: 0,
3594 },
3595 ));
3596 fake_model.end_last_completion_stream();
3597 cx.run_until_parked();
3598
3599 let message_2_id = UserMessageId::new();
3600 thread
3601 .update(cx, |thread, cx| {
3602 thread.send(message_2_id.clone(), ["Message 2"], cx)
3603 })
3604 .unwrap();
3605 cx.run_until_parked();
3606 fake_model.send_last_completion_stream_text_chunk("Response 2");
3607 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3608 language_model::TokenUsage {
3609 input_tokens: 250,
3610 output_tokens: 75,
3611 cache_creation_input_tokens: 0,
3612 cache_read_input_tokens: 0,
3613 },
3614 ));
3615 fake_model.end_last_completion_stream();
3616 cx.run_until_parked();
3617
3618 // Verify initial state
3619 thread.read_with(cx, |thread, _| {
3620 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3621 });
3622
3623 // Truncate at message 2 (removes message 2 and everything after)
3624 thread
3625 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3626 .unwrap();
3627 cx.run_until_parked();
3628
3629 // After truncation, message_2_id no longer exists, so lookup should return None
3630 thread.read_with(cx, |thread, _| {
3631 assert_eq!(
3632 thread.tokens_before_message(&message_2_id),
3633 None,
3634 "After truncation, message 2 no longer exists"
3635 );
3636 // Message 1 still exists but has no tokens before it
3637 assert_eq!(
3638 thread.tokens_before_message(&message_1_id),
3639 None,
3640 "First message still has no tokens before it"
3641 );
3642 });
3643}
3644
/// Exercises the terminal tool's permission rules through four scenarios that
/// share one project and one global `AgentSettings`.
///
/// Each scenario clones the current global settings, replaces the rule set
/// stored under `TerminalTool::NAME`, and overrides the global before running
/// a command. Because the settings are global, every scenario starts from
/// whatever the previous one left behind; only the fields it assigns
/// explicitly are guaranteed.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    // NOTE(review): `always_allow_tool_actions` is not assigned here, so this
    // scenario runs with whatever value `init_test`'s settings provide.
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        // Compare case-insensitively so the check does not depend on the
        // capitalization of the error message.
        let err_msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The first update must already carry terminal content, i.e. the
        // command started without a preceding authorization request.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, confirm patterns are overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }

    // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, even default_mode: Deny is overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }
}
3853
3854#[gpui::test]
3855async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3856 init_test(cx);
3857
3858 cx.update(|cx| {
3859 cx.update_flags(true, vec!["subagents".to_string()]);
3860 });
3861
3862 let fs = FakeFs::new(cx.executor());
3863 fs.insert_tree(path!("/test"), json!({})).await;
3864 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3865 let project_context = cx.new(|_cx| ProjectContext::default());
3866 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3867 let context_server_registry =
3868 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3869 let model = Arc::new(FakeLanguageModel::default());
3870
3871 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3872 let environment = Rc::new(FakeThreadEnvironment { handle });
3873
3874 let thread = cx.new(|cx| {
3875 let mut thread = Thread::new(
3876 project.clone(),
3877 project_context,
3878 context_server_registry,
3879 Templates::new(),
3880 Some(model),
3881 cx,
3882 );
3883 thread.add_default_tools(environment, cx);
3884 thread
3885 });
3886
3887 thread.read_with(cx, |thread, _| {
3888 assert!(
3889 thread.has_registered_tool(SubagentTool::NAME),
3890 "subagent tool should be present when feature flag is enabled"
3891 );
3892 });
3893}
3894
3895#[gpui::test]
3896async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3897 init_test(cx);
3898
3899 cx.update(|cx| {
3900 cx.update_flags(true, vec!["subagents".to_string()]);
3901 });
3902
3903 let fs = FakeFs::new(cx.executor());
3904 fs.insert_tree(path!("/test"), json!({})).await;
3905 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3906 let project_context = cx.new(|_cx| ProjectContext::default());
3907 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3908 let context_server_registry =
3909 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3910 let model = Arc::new(FakeLanguageModel::default());
3911
3912 let subagent_context = SubagentContext {
3913 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3914 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3915 depth: 1,
3916 summary_prompt: "Summarize".to_string(),
3917 context_low_prompt: "Context low".to_string(),
3918 };
3919
3920 let subagent = cx.new(|cx| {
3921 Thread::new_subagent(
3922 project.clone(),
3923 project_context,
3924 context_server_registry,
3925 Templates::new(),
3926 model.clone(),
3927 subagent_context,
3928 std::collections::BTreeMap::new(),
3929 cx,
3930 )
3931 });
3932
3933 subagent.read_with(cx, |thread, _| {
3934 assert!(thread.is_subagent());
3935 assert_eq!(thread.depth(), 1);
3936 assert!(thread.model().is_some());
3937 });
3938}
3939
3940#[gpui::test]
3941async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
3942 init_test(cx);
3943
3944 cx.update(|cx| {
3945 cx.update_flags(true, vec!["subagents".to_string()]);
3946 });
3947
3948 let fs = FakeFs::new(cx.executor());
3949 fs.insert_tree(path!("/test"), json!({})).await;
3950 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3951 let project_context = cx.new(|_cx| ProjectContext::default());
3952 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3953 let context_server_registry =
3954 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3955 let model = Arc::new(FakeLanguageModel::default());
3956
3957 let subagent_context = SubagentContext {
3958 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3959 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3960 depth: MAX_SUBAGENT_DEPTH,
3961 summary_prompt: "Summarize".to_string(),
3962 context_low_prompt: "Context low".to_string(),
3963 };
3964
3965 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3966 let environment = Rc::new(FakeThreadEnvironment { handle });
3967
3968 let deep_subagent = cx.new(|cx| {
3969 let mut thread = Thread::new_subagent(
3970 project.clone(),
3971 project_context,
3972 context_server_registry,
3973 Templates::new(),
3974 model.clone(),
3975 subagent_context,
3976 std::collections::BTreeMap::new(),
3977 cx,
3978 );
3979 thread.add_default_tools(environment, cx);
3980 thread
3981 });
3982
3983 deep_subagent.read_with(cx, |thread, _| {
3984 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3985 assert!(
3986 !thread.has_registered_tool(SubagentTool::NAME),
3987 "subagent tool should not be present at max depth"
3988 );
3989 });
3990}
3991
3992#[gpui::test]
3993async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3994 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3995 let fake_model = model.as_fake();
3996
3997 cx.update(|cx| {
3998 cx.update_flags(true, vec!["subagents".to_string()]);
3999 });
4000
4001 let subagent_context = SubagentContext {
4002 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4003 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4004 depth: 1,
4005 summary_prompt: "Summarize your work".to_string(),
4006 context_low_prompt: "Context low, wrap up".to_string(),
4007 };
4008
4009 let project = thread.read_with(cx, |t, _| t.project.clone());
4010 let project_context = cx.new(|_cx| ProjectContext::default());
4011 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4012 let context_server_registry =
4013 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4014
4015 let subagent = cx.new(|cx| {
4016 Thread::new_subagent(
4017 project.clone(),
4018 project_context,
4019 context_server_registry,
4020 Templates::new(),
4021 model.clone(),
4022 subagent_context,
4023 std::collections::BTreeMap::new(),
4024 cx,
4025 )
4026 });
4027
4028 let task_prompt = "Find all TODO comments in the codebase";
4029 subagent
4030 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4031 .unwrap();
4032 cx.run_until_parked();
4033
4034 let pending = fake_model.pending_completions();
4035 assert_eq!(pending.len(), 1, "should have one pending completion");
4036
4037 let messages = &pending[0].messages;
4038 let user_messages: Vec<_> = messages
4039 .iter()
4040 .filter(|m| m.role == language_model::Role::User)
4041 .collect();
4042 assert_eq!(user_messages.len(), 1, "should have one user message");
4043
4044 let content = &user_messages[0].content[0];
4045 assert!(
4046 content.to_str().unwrap().contains("TODO"),
4047 "task prompt should be in user message"
4048 );
4049}
4050
4051#[gpui::test]
4052async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4053 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4054 let fake_model = model.as_fake();
4055
4056 cx.update(|cx| {
4057 cx.update_flags(true, vec!["subagents".to_string()]);
4058 });
4059
4060 let subagent_context = SubagentContext {
4061 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4062 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4063 depth: 1,
4064 summary_prompt: "Please summarize what you found".to_string(),
4065 context_low_prompt: "Context low, wrap up".to_string(),
4066 };
4067
4068 let project = thread.read_with(cx, |t, _| t.project.clone());
4069 let project_context = cx.new(|_cx| ProjectContext::default());
4070 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4071 let context_server_registry =
4072 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4073
4074 let subagent = cx.new(|cx| {
4075 Thread::new_subagent(
4076 project.clone(),
4077 project_context,
4078 context_server_registry,
4079 Templates::new(),
4080 model.clone(),
4081 subagent_context,
4082 std::collections::BTreeMap::new(),
4083 cx,
4084 )
4085 });
4086
4087 subagent
4088 .update(cx, |thread, cx| {
4089 thread.submit_user_message("Do some work", cx)
4090 })
4091 .unwrap();
4092 cx.run_until_parked();
4093
4094 fake_model.send_last_completion_stream_text_chunk("I did the work");
4095 fake_model
4096 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4097 fake_model.end_last_completion_stream();
4098 cx.run_until_parked();
4099
4100 subagent
4101 .update(cx, |thread, cx| thread.request_final_summary(cx))
4102 .unwrap();
4103 cx.run_until_parked();
4104
4105 let pending = fake_model.pending_completions();
4106 assert!(
4107 !pending.is_empty(),
4108 "should have pending completion for summary"
4109 );
4110
4111 let messages = &pending.last().unwrap().messages;
4112 let user_messages: Vec<_> = messages
4113 .iter()
4114 .filter(|m| m.role == language_model::Role::User)
4115 .collect();
4116
4117 let last_user = user_messages.last().unwrap();
4118 assert!(
4119 last_user.content[0].to_str().unwrap().contains("summarize"),
4120 "summary prompt should be sent"
4121 );
4122}
4123
4124#[gpui::test]
4125async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4126 init_test(cx);
4127
4128 cx.update(|cx| {
4129 cx.update_flags(true, vec!["subagents".to_string()]);
4130 });
4131
4132 let fs = FakeFs::new(cx.executor());
4133 fs.insert_tree(path!("/test"), json!({})).await;
4134 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4135 let project_context = cx.new(|_cx| ProjectContext::default());
4136 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4137 let context_server_registry =
4138 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4139 let model = Arc::new(FakeLanguageModel::default());
4140
4141 let subagent_context = SubagentContext {
4142 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4143 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4144 depth: 1,
4145 summary_prompt: "Summarize".to_string(),
4146 context_low_prompt: "Context low".to_string(),
4147 };
4148
4149 let subagent = cx.new(|cx| {
4150 let mut thread = Thread::new_subagent(
4151 project.clone(),
4152 project_context,
4153 context_server_registry,
4154 Templates::new(),
4155 model.clone(),
4156 subagent_context,
4157 std::collections::BTreeMap::new(),
4158 cx,
4159 );
4160 thread.add_tool(EchoTool);
4161 thread.add_tool(DelayTool);
4162 thread.add_tool(WordListTool);
4163 thread
4164 });
4165
4166 subagent.read_with(cx, |thread, _| {
4167 assert!(thread.has_registered_tool("echo"));
4168 assert!(thread.has_registered_tool("delay"));
4169 assert!(thread.has_registered_tool("word_list"));
4170 });
4171
4172 let allowed: collections::HashSet<gpui::SharedString> =
4173 vec!["echo".into()].into_iter().collect();
4174
4175 subagent.update(cx, |thread, _cx| {
4176 thread.restrict_tools(&allowed);
4177 });
4178
4179 subagent.read_with(cx, |thread, _| {
4180 assert!(
4181 thread.has_registered_tool("echo"),
4182 "echo should still be available"
4183 );
4184 assert!(
4185 !thread.has_registered_tool("delay"),
4186 "delay should be removed"
4187 );
4188 assert!(
4189 !thread.has_registered_tool("word_list"),
4190 "word_list should be removed"
4191 );
4192 });
4193}
4194
4195#[gpui::test]
4196async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4197 init_test(cx);
4198
4199 cx.update(|cx| {
4200 cx.update_flags(true, vec!["subagents".to_string()]);
4201 });
4202
4203 let fs = FakeFs::new(cx.executor());
4204 fs.insert_tree(path!("/test"), json!({})).await;
4205 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4206 let project_context = cx.new(|_cx| ProjectContext::default());
4207 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4208 let context_server_registry =
4209 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4210 let model = Arc::new(FakeLanguageModel::default());
4211
4212 let parent = cx.new(|cx| {
4213 Thread::new(
4214 project.clone(),
4215 project_context.clone(),
4216 context_server_registry.clone(),
4217 Templates::new(),
4218 Some(model.clone()),
4219 cx,
4220 )
4221 });
4222
4223 let subagent_context = SubagentContext {
4224 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4225 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4226 depth: 1,
4227 summary_prompt: "Summarize".to_string(),
4228 context_low_prompt: "Context low".to_string(),
4229 };
4230
4231 let subagent = cx.new(|cx| {
4232 Thread::new_subagent(
4233 project.clone(),
4234 project_context.clone(),
4235 context_server_registry.clone(),
4236 Templates::new(),
4237 model.clone(),
4238 subagent_context,
4239 std::collections::BTreeMap::new(),
4240 cx,
4241 )
4242 });
4243
4244 parent.update(cx, |thread, _cx| {
4245 thread.register_running_subagent(subagent.downgrade());
4246 });
4247
4248 subagent
4249 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4250 .unwrap();
4251 cx.run_until_parked();
4252
4253 subagent.read_with(cx, |thread, _| {
4254 assert!(!thread.is_turn_complete(), "subagent should be running");
4255 });
4256
4257 parent.update(cx, |thread, cx| {
4258 thread.cancel(cx).detach();
4259 });
4260
4261 subagent.read_with(cx, |thread, _| {
4262 assert!(
4263 thread.is_turn_complete(),
4264 "subagent should be cancelled when parent cancels"
4265 );
4266 });
4267}
4268
4269#[gpui::test]
4270async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4271 // This test verifies that the subagent tool properly handles user cancellation
4272 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4273 init_test(cx);
4274 always_allow_tools(cx);
4275
4276 cx.update(|cx| {
4277 cx.update_flags(true, vec!["subagents".to_string()]);
4278 });
4279
4280 let fs = FakeFs::new(cx.executor());
4281 fs.insert_tree(path!("/test"), json!({})).await;
4282 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4283 let project_context = cx.new(|_cx| ProjectContext::default());
4284 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4285 let context_server_registry =
4286 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4287 let model = Arc::new(FakeLanguageModel::default());
4288
4289 let parent = cx.new(|cx| {
4290 Thread::new(
4291 project.clone(),
4292 project_context.clone(),
4293 context_server_registry.clone(),
4294 Templates::new(),
4295 Some(model.clone()),
4296 cx,
4297 )
4298 });
4299
4300 #[allow(clippy::arc_with_non_send_sync)]
4301 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4302
4303 let (event_stream, _rx, mut cancellation_tx) =
4304 crate::ToolCallEventStream::test_with_cancellation();
4305
4306 // Start the subagent tool
4307 let task = cx.update(|cx| {
4308 tool.run(
4309 SubagentToolInput {
4310 label: "Long running task".to_string(),
4311 task_prompt: "Do a very long task that takes forever".to_string(),
4312 summary_prompt: "Summarize".to_string(),
4313 context_low_prompt: "Context low".to_string(),
4314 timeout_ms: None,
4315 allowed_tools: None,
4316 },
4317 event_stream.clone(),
4318 cx,
4319 )
4320 });
4321
4322 cx.run_until_parked();
4323
4324 // Signal cancellation via the event stream
4325 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4326
4327 // The task should complete promptly with a cancellation error
4328 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4329 let result = futures::select! {
4330 result = task.fuse() => result,
4331 _ = timeout.fuse() => {
4332 panic!("subagent tool did not respond to cancellation within timeout");
4333 }
4334 };
4335
4336 // Verify we got a cancellation error
4337 let err = result.unwrap_err();
4338 assert!(
4339 err.to_string().contains("cancelled by user"),
4340 "expected cancellation error, got: {}",
4341 err
4342 );
4343}
4344
4345#[gpui::test]
4346async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4347 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4348 let fake_model = model.as_fake();
4349
4350 cx.update(|cx| {
4351 cx.update_flags(true, vec!["subagents".to_string()]);
4352 });
4353
4354 let subagent_context = SubagentContext {
4355 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4356 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4357 depth: 1,
4358 summary_prompt: "Summarize".to_string(),
4359 context_low_prompt: "Context low".to_string(),
4360 };
4361
4362 let project = thread.read_with(cx, |t, _| t.project.clone());
4363 let project_context = cx.new(|_cx| ProjectContext::default());
4364 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4365 let context_server_registry =
4366 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4367
4368 let subagent = cx.new(|cx| {
4369 Thread::new_subagent(
4370 project.clone(),
4371 project_context,
4372 context_server_registry,
4373 Templates::new(),
4374 model.clone(),
4375 subagent_context,
4376 std::collections::BTreeMap::new(),
4377 cx,
4378 )
4379 });
4380
4381 subagent
4382 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4383 .unwrap();
4384 cx.run_until_parked();
4385
4386 subagent.read_with(cx, |thread, _| {
4387 assert!(!thread.is_turn_complete(), "turn should be in progress");
4388 });
4389
4390 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4391 provider: LanguageModelProviderName::from("Fake".to_string()),
4392 });
4393 fake_model.end_last_completion_stream();
4394 cx.run_until_parked();
4395
4396 subagent.read_with(cx, |thread, _| {
4397 assert!(
4398 thread.is_turn_complete(),
4399 "turn should be complete after non-retryable error"
4400 );
4401 });
4402}
4403
4404#[gpui::test]
4405async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4406 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4407 let fake_model = model.as_fake();
4408
4409 cx.update(|cx| {
4410 cx.update_flags(true, vec!["subagents".to_string()]);
4411 });
4412
4413 let subagent_context = SubagentContext {
4414 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4415 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4416 depth: 1,
4417 summary_prompt: "Summarize your work".to_string(),
4418 context_low_prompt: "Context low, stop and summarize".to_string(),
4419 };
4420
4421 let project = thread.read_with(cx, |t, _| t.project.clone());
4422 let project_context = cx.new(|_cx| ProjectContext::default());
4423 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4424 let context_server_registry =
4425 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4426
4427 let subagent = cx.new(|cx| {
4428 Thread::new_subagent(
4429 project.clone(),
4430 project_context.clone(),
4431 context_server_registry.clone(),
4432 Templates::new(),
4433 model.clone(),
4434 subagent_context.clone(),
4435 std::collections::BTreeMap::new(),
4436 cx,
4437 )
4438 });
4439
4440 subagent.update(cx, |thread, _| {
4441 thread.add_tool(EchoTool);
4442 });
4443
4444 subagent
4445 .update(cx, |thread, cx| {
4446 thread.submit_user_message("Do some work", cx)
4447 })
4448 .unwrap();
4449 cx.run_until_parked();
4450
4451 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4452 fake_model
4453 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4454 fake_model.end_last_completion_stream();
4455 cx.run_until_parked();
4456
4457 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4458 assert!(
4459 interrupt_result.is_ok(),
4460 "interrupt_for_summary should succeed"
4461 );
4462
4463 cx.run_until_parked();
4464
4465 let pending = fake_model.pending_completions();
4466 assert!(
4467 !pending.is_empty(),
4468 "should have pending completion for interrupted summary"
4469 );
4470
4471 let messages = &pending.last().unwrap().messages;
4472 let user_messages: Vec<_> = messages
4473 .iter()
4474 .filter(|m| m.role == language_model::Role::User)
4475 .collect();
4476
4477 let last_user = user_messages.last().unwrap();
4478 let content_str = last_user.content[0].to_str().unwrap();
4479 assert!(
4480 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4481 "context_low_prompt should be sent when interrupting: got {:?}",
4482 content_str
4483 );
4484}
4485
4486#[gpui::test]
4487async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4488 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4489 let fake_model = model.as_fake();
4490
4491 cx.update(|cx| {
4492 cx.update_flags(true, vec!["subagents".to_string()]);
4493 });
4494
4495 let subagent_context = SubagentContext {
4496 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4497 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4498 depth: 1,
4499 summary_prompt: "Summarize".to_string(),
4500 context_low_prompt: "Context low".to_string(),
4501 };
4502
4503 let project = thread.read_with(cx, |t, _| t.project.clone());
4504 let project_context = cx.new(|_cx| ProjectContext::default());
4505 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4506 let context_server_registry =
4507 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4508
4509 let subagent = cx.new(|cx| {
4510 Thread::new_subagent(
4511 project.clone(),
4512 project_context,
4513 context_server_registry,
4514 Templates::new(),
4515 model.clone(),
4516 subagent_context,
4517 std::collections::BTreeMap::new(),
4518 cx,
4519 )
4520 });
4521
4522 subagent
4523 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4524 .unwrap();
4525 cx.run_until_parked();
4526
4527 let max_tokens = model.max_token_count();
4528 let high_usage = language_model::TokenUsage {
4529 input_tokens: (max_tokens as f64 * 0.80) as u64,
4530 output_tokens: 0,
4531 cache_creation_input_tokens: 0,
4532 cache_read_input_tokens: 0,
4533 };
4534
4535 fake_model
4536 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4537 fake_model.send_last_completion_stream_text_chunk("Working...");
4538 fake_model
4539 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4540 fake_model.end_last_completion_stream();
4541 cx.run_until_parked();
4542
4543 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4544 assert!(usage.is_some(), "should have token usage after completion");
4545
4546 let usage = usage.unwrap();
4547 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4548 assert!(
4549 remaining_ratio <= 0.25,
4550 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4551 remaining_ratio * 100.0
4552 );
4553}
4554
4555#[gpui::test]
4556async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4557 init_test(cx);
4558
4559 cx.update(|cx| {
4560 cx.update_flags(true, vec!["subagents".to_string()]);
4561 });
4562
4563 let fs = FakeFs::new(cx.executor());
4564 fs.insert_tree(path!("/test"), json!({})).await;
4565 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4566 let project_context = cx.new(|_cx| ProjectContext::default());
4567 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4568 let context_server_registry =
4569 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4570 let model = Arc::new(FakeLanguageModel::default());
4571
4572 let parent = cx.new(|cx| {
4573 let mut thread = Thread::new(
4574 project.clone(),
4575 project_context.clone(),
4576 context_server_registry.clone(),
4577 Templates::new(),
4578 Some(model.clone()),
4579 cx,
4580 );
4581 thread.add_tool(EchoTool);
4582 thread
4583 });
4584
4585 #[allow(clippy::arc_with_non_send_sync)]
4586 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4587
4588 let allowed_tools = Some(vec!["nonexistent_tool".to_string()]);
4589 let result = cx.read(|cx| tool.validate_allowed_tools(&allowed_tools, cx));
4590
4591 assert!(result.is_err(), "should reject unknown tool");
4592 let err_msg = result.unwrap_err().to_string();
4593 assert!(
4594 err_msg.contains("nonexistent_tool"),
4595 "error should mention the invalid tool name: {}",
4596 err_msg
4597 );
4598 assert!(
4599 err_msg.contains("do not exist"),
4600 "error should explain the tool does not exist: {}",
4601 err_msg
4602 );
4603}
4604
4605#[gpui::test]
4606async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4607 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4608 let fake_model = model.as_fake();
4609
4610 cx.update(|cx| {
4611 cx.update_flags(true, vec!["subagents".to_string()]);
4612 });
4613
4614 let subagent_context = SubagentContext {
4615 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4616 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4617 depth: 1,
4618 summary_prompt: "Summarize".to_string(),
4619 context_low_prompt: "Context low".to_string(),
4620 };
4621
4622 let project = thread.read_with(cx, |t, _| t.project.clone());
4623 let project_context = cx.new(|_cx| ProjectContext::default());
4624 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4625 let context_server_registry =
4626 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4627
4628 let subagent = cx.new(|cx| {
4629 Thread::new_subagent(
4630 project.clone(),
4631 project_context,
4632 context_server_registry,
4633 Templates::new(),
4634 model.clone(),
4635 subagent_context,
4636 std::collections::BTreeMap::new(),
4637 cx,
4638 )
4639 });
4640
4641 subagent
4642 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4643 .unwrap();
4644 cx.run_until_parked();
4645
4646 fake_model
4647 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4648 fake_model.end_last_completion_stream();
4649 cx.run_until_parked();
4650
4651 subagent.read_with(cx, |thread, _| {
4652 assert!(
4653 thread.is_turn_complete(),
4654 "turn should complete even with empty response"
4655 );
4656 });
4657}
4658
4659#[gpui::test]
4660async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4661 init_test(cx);
4662
4663 cx.update(|cx| {
4664 cx.update_flags(true, vec!["subagents".to_string()]);
4665 });
4666
4667 let fs = FakeFs::new(cx.executor());
4668 fs.insert_tree(path!("/test"), json!({})).await;
4669 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4670 let project_context = cx.new(|_cx| ProjectContext::default());
4671 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4672 let context_server_registry =
4673 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4674 let model = Arc::new(FakeLanguageModel::default());
4675
4676 let depth_1_context = SubagentContext {
4677 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4678 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4679 depth: 1,
4680 summary_prompt: "Summarize".to_string(),
4681 context_low_prompt: "Context low".to_string(),
4682 };
4683
4684 let depth_1_subagent = cx.new(|cx| {
4685 Thread::new_subagent(
4686 project.clone(),
4687 project_context.clone(),
4688 context_server_registry.clone(),
4689 Templates::new(),
4690 model.clone(),
4691 depth_1_context,
4692 std::collections::BTreeMap::new(),
4693 cx,
4694 )
4695 });
4696
4697 depth_1_subagent.read_with(cx, |thread, _| {
4698 assert_eq!(thread.depth(), 1);
4699 assert!(thread.is_subagent());
4700 });
4701
4702 let depth_2_context = SubagentContext {
4703 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4704 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4705 depth: 2,
4706 summary_prompt: "Summarize depth 2".to_string(),
4707 context_low_prompt: "Context low depth 2".to_string(),
4708 };
4709
4710 let depth_2_subagent = cx.new(|cx| {
4711 Thread::new_subagent(
4712 project.clone(),
4713 project_context.clone(),
4714 context_server_registry.clone(),
4715 Templates::new(),
4716 model.clone(),
4717 depth_2_context,
4718 std::collections::BTreeMap::new(),
4719 cx,
4720 )
4721 });
4722
4723 depth_2_subagent.read_with(cx, |thread, _| {
4724 assert_eq!(thread.depth(), 2);
4725 assert!(thread.is_subagent());
4726 });
4727
4728 depth_2_subagent
4729 .update(cx, |thread, cx| {
4730 thread.submit_user_message("Nested task", cx)
4731 })
4732 .unwrap();
4733 cx.run_until_parked();
4734
4735 let pending = model.as_fake().pending_completions();
4736 assert!(
4737 !pending.is_empty(),
4738 "depth-2 subagent should be able to submit messages"
4739 );
4740}
4741
4742#[gpui::test]
4743async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4744 init_test(cx);
4745 always_allow_tools(cx);
4746
4747 cx.update(|cx| {
4748 cx.update_flags(true, vec!["subagents".to_string()]);
4749 });
4750
4751 let fs = FakeFs::new(cx.executor());
4752 fs.insert_tree(path!("/test"), json!({})).await;
4753 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4754 let project_context = cx.new(|_cx| ProjectContext::default());
4755 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4756 let context_server_registry =
4757 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4758 let model = Arc::new(FakeLanguageModel::default());
4759 let fake_model = model.as_fake();
4760
4761 let subagent_context = SubagentContext {
4762 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4763 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4764 depth: 1,
4765 summary_prompt: "Summarize what you did".to_string(),
4766 context_low_prompt: "Context low".to_string(),
4767 };
4768
4769 let subagent = cx.new(|cx| {
4770 let mut thread = Thread::new_subagent(
4771 project.clone(),
4772 project_context,
4773 context_server_registry,
4774 Templates::new(),
4775 model.clone(),
4776 subagent_context,
4777 std::collections::BTreeMap::new(),
4778 cx,
4779 );
4780 thread.add_tool(EchoTool);
4781 thread
4782 });
4783
4784 subagent.read_with(cx, |thread, _| {
4785 assert!(
4786 thread.has_registered_tool("echo"),
4787 "subagent should have echo tool"
4788 );
4789 });
4790
4791 subagent
4792 .update(cx, |thread, cx| {
4793 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4794 })
4795 .unwrap();
4796 cx.run_until_parked();
4797
4798 let tool_use = LanguageModelToolUse {
4799 id: "tool_call_1".into(),
4800 name: EchoTool::NAME.into(),
4801 raw_input: json!({"text": "hello world"}).to_string(),
4802 input: json!({"text": "hello world"}),
4803 is_input_complete: true,
4804 thought_signature: None,
4805 };
4806 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4807 fake_model.end_last_completion_stream();
4808 cx.run_until_parked();
4809
4810 let pending = fake_model.pending_completions();
4811 assert!(
4812 !pending.is_empty(),
4813 "should have pending completion after tool use"
4814 );
4815
4816 let last_completion = pending.last().unwrap();
4817 let has_tool_result = last_completion.messages.iter().any(|m| {
4818 m.content
4819 .iter()
4820 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4821 });
4822 assert!(
4823 has_tool_result,
4824 "tool result should be in the messages sent back to the model"
4825 );
4826}
4827
4828#[gpui::test]
4829async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4830 init_test(cx);
4831
4832 cx.update(|cx| {
4833 cx.update_flags(true, vec!["subagents".to_string()]);
4834 });
4835
4836 let fs = FakeFs::new(cx.executor());
4837 fs.insert_tree(path!("/test"), json!({})).await;
4838 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4839 let project_context = cx.new(|_cx| ProjectContext::default());
4840 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4841 let context_server_registry =
4842 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4843 let model = Arc::new(FakeLanguageModel::default());
4844
4845 let parent = cx.new(|cx| {
4846 Thread::new(
4847 project.clone(),
4848 project_context.clone(),
4849 context_server_registry.clone(),
4850 Templates::new(),
4851 Some(model.clone()),
4852 cx,
4853 )
4854 });
4855
4856 let mut subagents = Vec::new();
4857 for i in 0..MAX_PARALLEL_SUBAGENTS {
4858 let subagent_context = SubagentContext {
4859 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4860 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4861 depth: 1,
4862 summary_prompt: "Summarize".to_string(),
4863 context_low_prompt: "Context low".to_string(),
4864 };
4865
4866 let subagent = cx.new(|cx| {
4867 Thread::new_subagent(
4868 project.clone(),
4869 project_context.clone(),
4870 context_server_registry.clone(),
4871 Templates::new(),
4872 model.clone(),
4873 subagent_context,
4874 std::collections::BTreeMap::new(),
4875 cx,
4876 )
4877 });
4878
4879 parent.update(cx, |thread, _cx| {
4880 thread.register_running_subagent(subagent.downgrade());
4881 });
4882 subagents.push(subagent);
4883 }
4884
4885 parent.read_with(cx, |thread, _| {
4886 assert_eq!(
4887 thread.running_subagent_count(),
4888 MAX_PARALLEL_SUBAGENTS,
4889 "should have MAX_PARALLEL_SUBAGENTS registered"
4890 );
4891 });
4892
4893 #[allow(clippy::arc_with_non_send_sync)]
4894 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4895
4896 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4897
4898 let result = cx.update(|cx| {
4899 tool.run(
4900 SubagentToolInput {
4901 label: "Test".to_string(),
4902 task_prompt: "Do something".to_string(),
4903 summary_prompt: "Summarize".to_string(),
4904 context_low_prompt: "Context low".to_string(),
4905 timeout_ms: None,
4906 allowed_tools: None,
4907 },
4908 event_stream,
4909 cx,
4910 )
4911 });
4912
4913 let err = result.await.unwrap_err();
4914 assert!(
4915 err.to_string().contains("Maximum parallel subagents"),
4916 "should reject when max parallel subagents reached: {}",
4917 err
4918 );
4919
4920 drop(subagents);
4921}
4922
4923#[gpui::test]
4924async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4925 init_test(cx);
4926 always_allow_tools(cx);
4927
4928 cx.update(|cx| {
4929 cx.update_flags(true, vec!["subagents".to_string()]);
4930 });
4931
4932 let fs = FakeFs::new(cx.executor());
4933 fs.insert_tree(path!("/test"), json!({})).await;
4934 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4935 let project_context = cx.new(|_cx| ProjectContext::default());
4936 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4937 let context_server_registry =
4938 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4939 let model = Arc::new(FakeLanguageModel::default());
4940 let fake_model = model.as_fake();
4941
4942 let parent = cx.new(|cx| {
4943 let mut thread = Thread::new(
4944 project.clone(),
4945 project_context.clone(),
4946 context_server_registry.clone(),
4947 Templates::new(),
4948 Some(model.clone()),
4949 cx,
4950 );
4951 thread.add_tool(EchoTool);
4952 thread
4953 });
4954
4955 #[allow(clippy::arc_with_non_send_sync)]
4956 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4957
4958 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4959
4960 let task = cx.update(|cx| {
4961 tool.run(
4962 SubagentToolInput {
4963 label: "Research task".to_string(),
4964 task_prompt: "Find all TODOs in the codebase".to_string(),
4965 summary_prompt: "Summarize what you found".to_string(),
4966 context_low_prompt: "Context low, wrap up".to_string(),
4967 timeout_ms: None,
4968 allowed_tools: None,
4969 },
4970 event_stream,
4971 cx,
4972 )
4973 });
4974
4975 cx.run_until_parked();
4976
4977 let pending = fake_model.pending_completions();
4978 assert!(
4979 !pending.is_empty(),
4980 "subagent should have started and sent a completion request"
4981 );
4982
4983 let first_completion = &pending[0];
4984 let has_task_prompt = first_completion.messages.iter().any(|m| {
4985 m.role == language_model::Role::User
4986 && m.content
4987 .iter()
4988 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
4989 });
4990 assert!(has_task_prompt, "task prompt should be sent to subagent");
4991
4992 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
4993 fake_model
4994 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4995 fake_model.end_last_completion_stream();
4996 cx.run_until_parked();
4997
4998 let pending = fake_model.pending_completions();
4999 assert!(
5000 !pending.is_empty(),
5001 "should have pending completion for summary request"
5002 );
5003
5004 let last_completion = pending.last().unwrap();
5005 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5006 m.role == language_model::Role::User
5007 && m.content.iter().any(|c| {
5008 c.to_str()
5009 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5010 .unwrap_or(false)
5011 })
5012 });
5013 assert!(
5014 has_summary_prompt,
5015 "summary prompt should be sent after task completion"
5016 );
5017
5018 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5019 fake_model
5020 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5021 fake_model.end_last_completion_stream();
5022 cx.run_until_parked();
5023
5024 let result = task.await;
5025 assert!(result.is_ok(), "subagent tool should complete successfully");
5026
5027 let summary = result.unwrap();
5028 assert!(
5029 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5030 "summary should contain subagent's response: {}",
5031 summary
5032 );
5033}
5034
5035#[gpui::test]
5036async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5037 init_test(cx);
5038
5039 let fs = FakeFs::new(cx.executor());
5040 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5041 .await;
5042 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5043
5044 cx.update(|cx| {
5045 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5046 settings.tool_permissions.tools.insert(
5047 EditFileTool::NAME.into(),
5048 agent_settings::ToolRules {
5049 default_mode: settings::ToolPermissionMode::Allow,
5050 always_allow: vec![],
5051 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5052 always_confirm: vec![],
5053 invalid_patterns: vec![],
5054 },
5055 );
5056 agent_settings::AgentSettings::override_global(settings, cx);
5057 });
5058
5059 let context_server_registry =
5060 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5061 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5062 let templates = crate::Templates::new();
5063 let thread = cx.new(|cx| {
5064 crate::Thread::new(
5065 project.clone(),
5066 cx.new(|_cx| prompt_store::ProjectContext::default()),
5067 context_server_registry,
5068 templates.clone(),
5069 None,
5070 cx,
5071 )
5072 });
5073
5074 #[allow(clippy::arc_with_non_send_sync)]
5075 let tool = Arc::new(crate::EditFileTool::new(
5076 project.clone(),
5077 thread.downgrade(),
5078 language_registry,
5079 templates,
5080 ));
5081 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5082
5083 let task = cx.update(|cx| {
5084 tool.run(
5085 crate::EditFileToolInput {
5086 display_description: "Edit sensitive file".to_string(),
5087 path: "root/sensitive_config.txt".into(),
5088 mode: crate::EditFileMode::Edit,
5089 },
5090 event_stream,
5091 cx,
5092 )
5093 });
5094
5095 let result = task.await;
5096 assert!(result.is_err(), "expected edit to be blocked");
5097 assert!(
5098 result.unwrap_err().to_string().contains("blocked"),
5099 "error should mention the edit was blocked"
5100 );
5101}
5102
5103#[gpui::test]
5104async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5105 init_test(cx);
5106
5107 let fs = FakeFs::new(cx.executor());
5108 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5109 .await;
5110 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5111
5112 cx.update(|cx| {
5113 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5114 settings.tool_permissions.tools.insert(
5115 DeletePathTool::NAME.into(),
5116 agent_settings::ToolRules {
5117 default_mode: settings::ToolPermissionMode::Allow,
5118 always_allow: vec![],
5119 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5120 always_confirm: vec![],
5121 invalid_patterns: vec![],
5122 },
5123 );
5124 agent_settings::AgentSettings::override_global(settings, cx);
5125 });
5126
5127 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5128
5129 #[allow(clippy::arc_with_non_send_sync)]
5130 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5131 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5132
5133 let task = cx.update(|cx| {
5134 tool.run(
5135 crate::DeletePathToolInput {
5136 path: "root/important_data.txt".to_string(),
5137 },
5138 event_stream,
5139 cx,
5140 )
5141 });
5142
5143 let result = task.await;
5144 assert!(result.is_err(), "expected deletion to be blocked");
5145 assert!(
5146 result.unwrap_err().to_string().contains("blocked"),
5147 "error should mention the deletion was blocked"
5148 );
5149}
5150
5151#[gpui::test]
5152async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5153 init_test(cx);
5154
5155 let fs = FakeFs::new(cx.executor());
5156 fs.insert_tree(
5157 "/root",
5158 json!({
5159 "safe.txt": "content",
5160 "protected": {}
5161 }),
5162 )
5163 .await;
5164 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5165
5166 cx.update(|cx| {
5167 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5168 settings.tool_permissions.tools.insert(
5169 MovePathTool::NAME.into(),
5170 agent_settings::ToolRules {
5171 default_mode: settings::ToolPermissionMode::Allow,
5172 always_allow: vec![],
5173 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5174 always_confirm: vec![],
5175 invalid_patterns: vec![],
5176 },
5177 );
5178 agent_settings::AgentSettings::override_global(settings, cx);
5179 });
5180
5181 #[allow(clippy::arc_with_non_send_sync)]
5182 let tool = Arc::new(crate::MovePathTool::new(project));
5183 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5184
5185 let task = cx.update(|cx| {
5186 tool.run(
5187 crate::MovePathToolInput {
5188 source_path: "root/safe.txt".to_string(),
5189 destination_path: "root/protected/safe.txt".to_string(),
5190 },
5191 event_stream,
5192 cx,
5193 )
5194 });
5195
5196 let result = task.await;
5197 assert!(
5198 result.is_err(),
5199 "expected move to be blocked due to destination path"
5200 );
5201 assert!(
5202 result.unwrap_err().to_string().contains("blocked"),
5203 "error should mention the move was blocked"
5204 );
5205}
5206
5207#[gpui::test]
5208async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5209 init_test(cx);
5210
5211 let fs = FakeFs::new(cx.executor());
5212 fs.insert_tree(
5213 "/root",
5214 json!({
5215 "secret.txt": "secret content",
5216 "public": {}
5217 }),
5218 )
5219 .await;
5220 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5221
5222 cx.update(|cx| {
5223 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5224 settings.tool_permissions.tools.insert(
5225 MovePathTool::NAME.into(),
5226 agent_settings::ToolRules {
5227 default_mode: settings::ToolPermissionMode::Allow,
5228 always_allow: vec![],
5229 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5230 always_confirm: vec![],
5231 invalid_patterns: vec![],
5232 },
5233 );
5234 agent_settings::AgentSettings::override_global(settings, cx);
5235 });
5236
5237 #[allow(clippy::arc_with_non_send_sync)]
5238 let tool = Arc::new(crate::MovePathTool::new(project));
5239 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5240
5241 let task = cx.update(|cx| {
5242 tool.run(
5243 crate::MovePathToolInput {
5244 source_path: "root/secret.txt".to_string(),
5245 destination_path: "root/public/not_secret.txt".to_string(),
5246 },
5247 event_stream,
5248 cx,
5249 )
5250 });
5251
5252 let result = task.await;
5253 assert!(
5254 result.is_err(),
5255 "expected move to be blocked due to source path"
5256 );
5257 assert!(
5258 result.unwrap_err().to_string().contains("blocked"),
5259 "error should mention the move was blocked"
5260 );
5261}
5262
5263#[gpui::test]
5264async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5265 init_test(cx);
5266
5267 let fs = FakeFs::new(cx.executor());
5268 fs.insert_tree(
5269 "/root",
5270 json!({
5271 "confidential.txt": "confidential data",
5272 "dest": {}
5273 }),
5274 )
5275 .await;
5276 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5277
5278 cx.update(|cx| {
5279 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5280 settings.tool_permissions.tools.insert(
5281 CopyPathTool::NAME.into(),
5282 agent_settings::ToolRules {
5283 default_mode: settings::ToolPermissionMode::Allow,
5284 always_allow: vec![],
5285 always_deny: vec![
5286 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5287 ],
5288 always_confirm: vec![],
5289 invalid_patterns: vec![],
5290 },
5291 );
5292 agent_settings::AgentSettings::override_global(settings, cx);
5293 });
5294
5295 #[allow(clippy::arc_with_non_send_sync)]
5296 let tool = Arc::new(crate::CopyPathTool::new(project));
5297 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5298
5299 let task = cx.update(|cx| {
5300 tool.run(
5301 crate::CopyPathToolInput {
5302 source_path: "root/confidential.txt".to_string(),
5303 destination_path: "root/dest/copy.txt".to_string(),
5304 },
5305 event_stream,
5306 cx,
5307 )
5308 });
5309
5310 let result = task.await;
5311 assert!(result.is_err(), "expected copy to be blocked");
5312 assert!(
5313 result.unwrap_err().to_string().contains("blocked"),
5314 "error should mention the copy was blocked"
5315 );
5316}
5317
5318#[gpui::test]
5319async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5320 init_test(cx);
5321
5322 let fs = FakeFs::new(cx.executor());
5323 fs.insert_tree(
5324 "/root",
5325 json!({
5326 "normal.txt": "normal content",
5327 "readonly": {
5328 "config.txt": "readonly content"
5329 }
5330 }),
5331 )
5332 .await;
5333 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5334
5335 cx.update(|cx| {
5336 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5337 settings.tool_permissions.tools.insert(
5338 SaveFileTool::NAME.into(),
5339 agent_settings::ToolRules {
5340 default_mode: settings::ToolPermissionMode::Allow,
5341 always_allow: vec![],
5342 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5343 always_confirm: vec![],
5344 invalid_patterns: vec![],
5345 },
5346 );
5347 agent_settings::AgentSettings::override_global(settings, cx);
5348 });
5349
5350 #[allow(clippy::arc_with_non_send_sync)]
5351 let tool = Arc::new(crate::SaveFileTool::new(project));
5352 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5353
5354 let task = cx.update(|cx| {
5355 tool.run(
5356 crate::SaveFileToolInput {
5357 paths: vec![
5358 std::path::PathBuf::from("root/normal.txt"),
5359 std::path::PathBuf::from("root/readonly/config.txt"),
5360 ],
5361 },
5362 event_stream,
5363 cx,
5364 )
5365 });
5366
5367 let result = task.await;
5368 assert!(
5369 result.is_err(),
5370 "expected save to be blocked due to denied path"
5371 );
5372 assert!(
5373 result.unwrap_err().to_string().contains("blocked"),
5374 "error should mention the save was blocked"
5375 );
5376}
5377
5378#[gpui::test]
5379async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5380 init_test(cx);
5381
5382 let fs = FakeFs::new(cx.executor());
5383 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5384 .await;
5385 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5386
5387 cx.update(|cx| {
5388 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5389 settings.always_allow_tool_actions = false;
5390 settings.tool_permissions.tools.insert(
5391 SaveFileTool::NAME.into(),
5392 agent_settings::ToolRules {
5393 default_mode: settings::ToolPermissionMode::Allow,
5394 always_allow: vec![],
5395 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5396 always_confirm: vec![],
5397 invalid_patterns: vec![],
5398 },
5399 );
5400 agent_settings::AgentSettings::override_global(settings, cx);
5401 });
5402
5403 #[allow(clippy::arc_with_non_send_sync)]
5404 let tool = Arc::new(crate::SaveFileTool::new(project));
5405 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5406
5407 let task = cx.update(|cx| {
5408 tool.run(
5409 crate::SaveFileToolInput {
5410 paths: vec![std::path::PathBuf::from("root/config.secret")],
5411 },
5412 event_stream,
5413 cx,
5414 )
5415 });
5416
5417 let result = task.await;
5418 assert!(result.is_err(), "expected save to be blocked");
5419 assert!(
5420 result.unwrap_err().to_string().contains("blocked"),
5421 "error should mention the save was blocked"
5422 );
5423}
5424
5425#[gpui::test]
5426async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5427 init_test(cx);
5428
5429 cx.update(|cx| {
5430 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5431 settings.tool_permissions.tools.insert(
5432 WebSearchTool::NAME.into(),
5433 agent_settings::ToolRules {
5434 default_mode: settings::ToolPermissionMode::Allow,
5435 always_allow: vec![],
5436 always_deny: vec![
5437 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5438 ],
5439 always_confirm: vec![],
5440 invalid_patterns: vec![],
5441 },
5442 );
5443 agent_settings::AgentSettings::override_global(settings, cx);
5444 });
5445
5446 #[allow(clippy::arc_with_non_send_sync)]
5447 let tool = Arc::new(crate::WebSearchTool);
5448 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5449
5450 let input: crate::WebSearchToolInput =
5451 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5452
5453 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5454
5455 let result = task.await;
5456 assert!(result.is_err(), "expected search to be blocked");
5457 assert!(
5458 result.unwrap_err().to_string().contains("blocked"),
5459 "error should mention the search was blocked"
5460 );
5461}
5462
5463#[gpui::test]
5464async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5465 init_test(cx);
5466
5467 let fs = FakeFs::new(cx.executor());
5468 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5469 .await;
5470 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5471
5472 cx.update(|cx| {
5473 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5474 settings.always_allow_tool_actions = false;
5475 settings.tool_permissions.tools.insert(
5476 EditFileTool::NAME.into(),
5477 agent_settings::ToolRules {
5478 default_mode: settings::ToolPermissionMode::Confirm,
5479 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5480 always_deny: vec![],
5481 always_confirm: vec![],
5482 invalid_patterns: vec![],
5483 },
5484 );
5485 agent_settings::AgentSettings::override_global(settings, cx);
5486 });
5487
5488 let context_server_registry =
5489 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5490 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5491 let templates = crate::Templates::new();
5492 let thread = cx.new(|cx| {
5493 crate::Thread::new(
5494 project.clone(),
5495 cx.new(|_cx| prompt_store::ProjectContext::default()),
5496 context_server_registry,
5497 templates.clone(),
5498 None,
5499 cx,
5500 )
5501 });
5502
5503 #[allow(clippy::arc_with_non_send_sync)]
5504 let tool = Arc::new(crate::EditFileTool::new(
5505 project,
5506 thread.downgrade(),
5507 language_registry,
5508 templates,
5509 ));
5510 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5511
5512 let _task = cx.update(|cx| {
5513 tool.run(
5514 crate::EditFileToolInput {
5515 display_description: "Edit README".to_string(),
5516 path: "root/README.md".into(),
5517 mode: crate::EditFileMode::Edit,
5518 },
5519 event_stream,
5520 cx,
5521 )
5522 });
5523
5524 cx.run_until_parked();
5525
5526 let event = rx.try_next();
5527 assert!(
5528 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5529 "expected no authorization request for allowed .md file"
5530 );
5531}
5532
5533#[gpui::test]
5534async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5535 init_test(cx);
5536
5537 cx.update(|cx| {
5538 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5539 settings.tool_permissions.tools.insert(
5540 FetchTool::NAME.into(),
5541 agent_settings::ToolRules {
5542 default_mode: settings::ToolPermissionMode::Allow,
5543 always_allow: vec![],
5544 always_deny: vec![
5545 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5546 ],
5547 always_confirm: vec![],
5548 invalid_patterns: vec![],
5549 },
5550 );
5551 agent_settings::AgentSettings::override_global(settings, cx);
5552 });
5553
5554 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5555
5556 #[allow(clippy::arc_with_non_send_sync)]
5557 let tool = Arc::new(crate::FetchTool::new(http_client));
5558 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5559
5560 let input: crate::FetchToolInput =
5561 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5562
5563 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5564
5565 let result = task.await;
5566 assert!(result.is_err(), "expected fetch to be blocked");
5567 assert!(
5568 result.unwrap_err().to_string().contains("blocked"),
5569 "error should mention the fetch was blocked"
5570 );
5571}
5572
5573#[gpui::test]
5574async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5575 init_test(cx);
5576
5577 cx.update(|cx| {
5578 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5579 settings.always_allow_tool_actions = false;
5580 settings.tool_permissions.tools.insert(
5581 FetchTool::NAME.into(),
5582 agent_settings::ToolRules {
5583 default_mode: settings::ToolPermissionMode::Confirm,
5584 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5585 always_deny: vec![],
5586 always_confirm: vec![],
5587 invalid_patterns: vec![],
5588 },
5589 );
5590 agent_settings::AgentSettings::override_global(settings, cx);
5591 });
5592
5593 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5594
5595 #[allow(clippy::arc_with_non_send_sync)]
5596 let tool = Arc::new(crate::FetchTool::new(http_client));
5597 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5598
5599 let input: crate::FetchToolInput =
5600 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5601
5602 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5603
5604 cx.run_until_parked();
5605
5606 let event = rx.try_next();
5607 assert!(
5608 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5609 "expected no authorization request for allowed docs.rs URL"
5610 );
5611}
5612
5613#[gpui::test]
5614async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5615 init_test(cx);
5616 always_allow_tools(cx);
5617
5618 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5619 let fake_model = model.as_fake();
5620
5621 // Add a tool so we can simulate tool calls
5622 thread.update(cx, |thread, _cx| {
5623 thread.add_tool(EchoTool);
5624 });
5625
5626 // Start a turn by sending a message
5627 let mut events = thread
5628 .update(cx, |thread, cx| {
5629 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5630 })
5631 .unwrap();
5632 cx.run_until_parked();
5633
5634 // Simulate the model making a tool call
5635 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5636 LanguageModelToolUse {
5637 id: "tool_1".into(),
5638 name: "echo".into(),
5639 raw_input: r#"{"text": "hello"}"#.into(),
5640 input: json!({"text": "hello"}),
5641 is_input_complete: true,
5642 thought_signature: None,
5643 },
5644 ));
5645 fake_model
5646 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5647
5648 // Signal that a message is queued before ending the stream
5649 thread.update(cx, |thread, _cx| {
5650 thread.set_has_queued_message(true);
5651 });
5652
5653 // Now end the stream - tool will run, and the boundary check should see the queue
5654 fake_model.end_last_completion_stream();
5655
5656 // Collect all events until the turn stops
5657 let all_events = collect_events_until_stop(&mut events, cx).await;
5658
5659 // Verify we received the tool call event
5660 let tool_call_ids: Vec<_> = all_events
5661 .iter()
5662 .filter_map(|e| match e {
5663 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5664 _ => None,
5665 })
5666 .collect();
5667 assert_eq!(
5668 tool_call_ids,
5669 vec!["tool_1"],
5670 "Should have received a tool call event for our echo tool"
5671 );
5672
5673 // The turn should have stopped with EndTurn
5674 let stop_reasons = stop_events(all_events);
5675 assert_eq!(
5676 stop_reasons,
5677 vec![acp::StopReason::EndTurn],
5678 "Turn should have ended after tool completion due to queued message"
5679 );
5680
5681 // Verify the queued message flag is still set
5682 thread.update(cx, |thread, _cx| {
5683 assert!(
5684 thread.has_queued_message(),
5685 "Should still have queued message flag set"
5686 );
5687 });
5688
5689 // Thread should be idle now
5690 thread.update(cx, |thread, _cx| {
5691 assert!(
5692 thread.is_turn_complete(),
5693 "Thread should not be running after turn ends"
5694 );
5695 });
5696}