1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for a terminal handed out by a thread environment.
///
/// Records whether it was killed, can be told to report a user stop, and exposes a shared
/// "wait for exit" task whose completion is decided by the constructor that built it.
struct FakeTerminalHandle {
    /// Set to `true` by `kill`; read back through `was_killed`.
    killed: Arc<AtomicBool>,
    /// Controlled by `set_stopped_by_user` and reported by `was_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot exit signal that `signal_exit` takes and fires at most once. It is wrapped in
    /// a `RefCell` because `signal_exit` only has `&self`.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Task handed out (cloned) by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Fixed output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Fixed id returned by `id`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeThreadEnvironment {
159 handle: Rc<FakeTerminalHandle>,
160}
161
162impl crate::ThreadEnvironment for FakeThreadEnvironment {
163 fn create_terminal(
164 &self,
165 _command: String,
166 _cwd: Option<std::path::PathBuf>,
167 _output_byte_limit: Option<u64>,
168 _cx: &mut AsyncApp,
169 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
170 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
171 }
172}
173
174/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
175struct MultiTerminalEnvironment {
176 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
177}
178
179impl MultiTerminalEnvironment {
180 fn new() -> Self {
181 Self {
182 handles: std::cell::RefCell::new(Vec::new()),
183 }
184 }
185
186 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
187 self.handles.borrow().clone()
188 }
189}
190
191impl crate::ThreadEnvironment for MultiTerminalEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
200 self.handles.borrow_mut().push(handle.clone());
201 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
202 }
203}
204
205fn always_allow_tools(cx: &mut TestAppContext) {
206 cx.update(|cx| {
207 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
208 settings.always_allow_tool_actions = true;
209 agent_settings::AgentSettings::override_global(settings, cx);
210 });
211}
212
213#[gpui::test]
214async fn test_echo(cx: &mut TestAppContext) {
215 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
216 let fake_model = model.as_fake();
217
218 let events = thread
219 .update(cx, |thread, cx| {
220 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
221 })
222 .unwrap();
223 cx.run_until_parked();
224 fake_model.send_last_completion_stream_text_chunk("Hello");
225 fake_model
226 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
227 fake_model.end_last_completion_stream();
228
229 let events = events.collect().await;
230 thread.update(cx, |thread, _cx| {
231 assert_eq!(
232 thread.last_message().unwrap().to_markdown(),
233 indoc! {"
234 ## Assistant
235
236 Hello
237 "}
238 )
239 });
240 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
241}
242
243#[gpui::test]
244async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
245 init_test(cx);
246 always_allow_tools(cx);
247
248 let fs = FakeFs::new(cx.executor());
249 let project = Project::test(fs, [], cx).await;
250
251 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
252 let environment = Rc::new(FakeThreadEnvironment {
253 handle: handle.clone(),
254 });
255
256 #[allow(clippy::arc_with_non_send_sync)]
257 let tool = Arc::new(crate::TerminalTool::new(project, environment));
258 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
259
260 let task = cx.update(|cx| {
261 tool.run(
262 crate::TerminalToolInput {
263 command: "sleep 1000".to_string(),
264 cd: ".".to_string(),
265 timeout_ms: Some(5),
266 },
267 event_stream,
268 cx,
269 )
270 });
271
272 let update = rx.expect_update_fields().await;
273 assert!(
274 update.content.iter().any(|blocks| {
275 blocks
276 .iter()
277 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
278 }),
279 "expected tool call update to include terminal content"
280 );
281
282 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
283
284 let deadline = std::time::Instant::now() + Duration::from_millis(500);
285 loop {
286 if let Some(result) = task_future.as_mut().now_or_never() {
287 let result = result.expect("terminal tool task should complete");
288
289 assert!(
290 handle.was_killed(),
291 "expected terminal handle to be killed on timeout"
292 );
293 assert!(
294 result.contains("partial output"),
295 "expected result to include terminal output, got: {result}"
296 );
297 return;
298 }
299
300 if std::time::Instant::now() >= deadline {
301 panic!("timed out waiting for terminal tool task to complete");
302 }
303
304 cx.run_until_parked();
305 cx.background_executor.timer(Duration::from_millis(1)).await;
306 }
307}
308
309#[gpui::test]
310#[ignore]
311async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
312 init_test(cx);
313 always_allow_tools(cx);
314
315 let fs = FakeFs::new(cx.executor());
316 let project = Project::test(fs, [], cx).await;
317
318 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
319 let environment = Rc::new(FakeThreadEnvironment {
320 handle: handle.clone(),
321 });
322
323 #[allow(clippy::arc_with_non_send_sync)]
324 let tool = Arc::new(crate::TerminalTool::new(project, environment));
325 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
326
327 let _task = cx.update(|cx| {
328 tool.run(
329 crate::TerminalToolInput {
330 command: "sleep 1000".to_string(),
331 cd: ".".to_string(),
332 timeout_ms: None,
333 },
334 event_stream,
335 cx,
336 )
337 });
338
339 let update = rx.expect_update_fields().await;
340 assert!(
341 update.content.iter().any(|blocks| {
342 blocks
343 .iter()
344 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
345 }),
346 "expected tool call update to include terminal content"
347 );
348
349 cx.background_executor
350 .timer(Duration::from_millis(25))
351 .await;
352
353 assert!(
354 !handle.was_killed(),
355 "did not expect terminal handle to be killed without a timeout"
356 );
357}
358
359#[gpui::test]
360async fn test_thinking(cx: &mut TestAppContext) {
361 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
362 let fake_model = model.as_fake();
363
364 let events = thread
365 .update(cx, |thread, cx| {
366 thread.send(
367 UserMessageId::new(),
368 [indoc! {"
369 Testing:
370
371 Generate a thinking step where you just think the word 'Think',
372 and have your final answer be 'Hello'
373 "}],
374 cx,
375 )
376 })
377 .unwrap();
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
380 text: "Think".to_string(),
381 signature: None,
382 });
383 fake_model.send_last_completion_stream_text_chunk("Hello");
384 fake_model
385 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
386 fake_model.end_last_completion_stream();
387
388 let events = events.collect().await;
389 thread.update(cx, |thread, _cx| {
390 assert_eq!(
391 thread.last_message().unwrap().to_markdown(),
392 indoc! {"
393 ## Assistant
394
395 <think>Think</think>
396 Hello
397 "}
398 )
399 });
400 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
401}
402
403#[gpui::test]
404async fn test_system_prompt(cx: &mut TestAppContext) {
405 let ThreadTest {
406 model,
407 thread,
408 project_context,
409 ..
410 } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 project_context.update(cx, |project_context, _cx| {
414 project_context.shell = "test-shell".into()
415 });
416 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
417 thread
418 .update(cx, |thread, cx| {
419 thread.send(UserMessageId::new(), ["abc"], cx)
420 })
421 .unwrap();
422 cx.run_until_parked();
423 let mut pending_completions = fake_model.pending_completions();
424 assert_eq!(
425 pending_completions.len(),
426 1,
427 "unexpected pending completions: {:?}",
428 pending_completions
429 );
430
431 let pending_completion = pending_completions.pop().unwrap();
432 assert_eq!(pending_completion.messages[0].role, Role::System);
433
434 let system_message = &pending_completion.messages[0];
435 let system_prompt = system_message.content[0].to_str().unwrap();
436 assert!(
437 system_prompt.contains("test-shell"),
438 "unexpected system message: {:?}",
439 system_message
440 );
441 assert!(
442 system_prompt.contains("## Fixing Diagnostics"),
443 "unexpected system message: {:?}",
444 system_message
445 );
446}
447
448#[gpui::test]
449async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
450 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
451 let fake_model = model.as_fake();
452
453 thread
454 .update(cx, |thread, cx| {
455 thread.send(UserMessageId::new(), ["abc"], cx)
456 })
457 .unwrap();
458 cx.run_until_parked();
459 let mut pending_completions = fake_model.pending_completions();
460 assert_eq!(
461 pending_completions.len(),
462 1,
463 "unexpected pending completions: {:?}",
464 pending_completions
465 );
466
467 let pending_completion = pending_completions.pop().unwrap();
468 assert_eq!(pending_completion.messages[0].role, Role::System);
469
470 let system_message = &pending_completion.messages[0];
471 let system_prompt = system_message.content[0].to_str().unwrap();
472 assert!(
473 !system_prompt.contains("## Tool Use"),
474 "unexpected system message: {:?}",
475 system_message
476 );
477 assert!(
478 !system_prompt.contains("## Fixing Diagnostics"),
479 "unexpected system message: {:?}",
480 system_message
481 );
482}
483
484#[gpui::test]
485async fn test_prompt_caching(cx: &mut TestAppContext) {
486 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
487 let fake_model = model.as_fake();
488
489 // Send initial user message and verify it's cached
490 thread
491 .update(cx, |thread, cx| {
492 thread.send(UserMessageId::new(), ["Message 1"], cx)
493 })
494 .unwrap();
495 cx.run_until_parked();
496
497 let completion = fake_model.pending_completions().pop().unwrap();
498 assert_eq!(
499 completion.messages[1..],
500 vec![LanguageModelRequestMessage {
501 role: Role::User,
502 content: vec!["Message 1".into()],
503 cache: true,
504 reasoning_details: None,
505 }]
506 );
507 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
508 "Response to Message 1".into(),
509 ));
510 fake_model.end_last_completion_stream();
511 cx.run_until_parked();
512
513 // Send another user message and verify only the latest is cached
514 thread
515 .update(cx, |thread, cx| {
516 thread.send(UserMessageId::new(), ["Message 2"], cx)
517 })
518 .unwrap();
519 cx.run_until_parked();
520
521 let completion = fake_model.pending_completions().pop().unwrap();
522 assert_eq!(
523 completion.messages[1..],
524 vec![
525 LanguageModelRequestMessage {
526 role: Role::User,
527 content: vec!["Message 1".into()],
528 cache: false,
529 reasoning_details: None,
530 },
531 LanguageModelRequestMessage {
532 role: Role::Assistant,
533 content: vec!["Response to Message 1".into()],
534 cache: false,
535 reasoning_details: None,
536 },
537 LanguageModelRequestMessage {
538 role: Role::User,
539 content: vec!["Message 2".into()],
540 cache: true,
541 reasoning_details: None,
542 }
543 ]
544 );
545 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
546 "Response to Message 2".into(),
547 ));
548 fake_model.end_last_completion_stream();
549 cx.run_until_parked();
550
551 // Simulate a tool call and verify that the latest tool result is cached
552 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
553 thread
554 .update(cx, |thread, cx| {
555 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
556 })
557 .unwrap();
558 cx.run_until_parked();
559
560 let tool_use = LanguageModelToolUse {
561 id: "tool_1".into(),
562 name: EchoTool::name().into(),
563 raw_input: json!({"text": "test"}).to_string(),
564 input: json!({"text": "test"}),
565 is_input_complete: true,
566 thought_signature: None,
567 };
568 fake_model
569 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
570 fake_model.end_last_completion_stream();
571 cx.run_until_parked();
572
573 let completion = fake_model.pending_completions().pop().unwrap();
574 let tool_result = LanguageModelToolResult {
575 tool_use_id: "tool_1".into(),
576 tool_name: EchoTool::name().into(),
577 is_error: false,
578 content: "test".into(),
579 output: Some("test".into()),
580 };
581 assert_eq!(
582 completion.messages[1..],
583 vec![
584 LanguageModelRequestMessage {
585 role: Role::User,
586 content: vec!["Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::Assistant,
592 content: vec!["Response to Message 1".into()],
593 cache: false,
594 reasoning_details: None,
595 },
596 LanguageModelRequestMessage {
597 role: Role::User,
598 content: vec!["Message 2".into()],
599 cache: false,
600 reasoning_details: None,
601 },
602 LanguageModelRequestMessage {
603 role: Role::Assistant,
604 content: vec!["Response to Message 2".into()],
605 cache: false,
606 reasoning_details: None,
607 },
608 LanguageModelRequestMessage {
609 role: Role::User,
610 content: vec!["Use the echo tool".into()],
611 cache: false,
612 reasoning_details: None,
613 },
614 LanguageModelRequestMessage {
615 role: Role::Assistant,
616 content: vec![MessageContent::ToolUse(tool_use)],
617 cache: false,
618 reasoning_details: None,
619 },
620 LanguageModelRequestMessage {
621 role: Role::User,
622 content: vec![MessageContent::ToolResult(tool_result)],
623 cache: true,
624 reasoning_details: None,
625 }
626 ]
627 );
628}
629
630#[gpui::test]
631#[cfg_attr(not(feature = "e2e"), ignore)]
632async fn test_basic_tool_calls(cx: &mut TestAppContext) {
633 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
634
635 // Test a tool call that's likely to complete *before* streaming stops.
636 let events = thread
637 .update(cx, |thread, cx| {
638 thread.add_tool(EchoTool);
639 thread.send(
640 UserMessageId::new(),
641 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
642 cx,
643 )
644 })
645 .unwrap()
646 .collect()
647 .await;
648 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
649
650 // Test a tool calls that's likely to complete *after* streaming stops.
651 let events = thread
652 .update(cx, |thread, cx| {
653 thread.remove_tool(&EchoTool::name());
654 thread.add_tool(DelayTool);
655 thread.send(
656 UserMessageId::new(),
657 [
658 "Now call the delay tool with 200ms.",
659 "When the timer goes off, then you echo the output of the tool.",
660 ],
661 cx,
662 )
663 })
664 .unwrap()
665 .collect()
666 .await;
667 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
668 thread.update(cx, |thread, _cx| {
669 assert!(
670 thread
671 .last_message()
672 .unwrap()
673 .as_agent_message()
674 .unwrap()
675 .content
676 .iter()
677 .any(|content| {
678 if let AgentMessageContent::Text(text) = content {
679 text.contains("Ding")
680 } else {
681 false
682 }
683 }),
684 "{}",
685 thread.to_markdown()
686 );
687 });
688}
689
690#[gpui::test]
691#[cfg_attr(not(feature = "e2e"), ignore)]
692async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
693 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
694
695 // Test a tool call that's likely to complete *before* streaming stops.
696 let mut events = thread
697 .update(cx, |thread, cx| {
698 thread.add_tool(WordListTool);
699 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
700 })
701 .unwrap();
702
703 let mut saw_partial_tool_use = false;
704 while let Some(event) = events.next().await {
705 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
706 thread.update(cx, |thread, _cx| {
707 // Look for a tool use in the thread's last message
708 let message = thread.last_message().unwrap();
709 let agent_message = message.as_agent_message().unwrap();
710 let last_content = agent_message.content.last().unwrap();
711 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
712 assert_eq!(last_tool_use.name.as_ref(), "word_list");
713 if tool_call.status == acp::ToolCallStatus::Pending {
714 if !last_tool_use.is_input_complete
715 && last_tool_use.input.get("g").is_none()
716 {
717 saw_partial_tool_use = true;
718 }
719 } else {
720 last_tool_use
721 .input
722 .get("a")
723 .expect("'a' has streamed because input is now complete");
724 last_tool_use
725 .input
726 .get("g")
727 .expect("'g' has streamed because input is now complete");
728 }
729 } else {
730 panic!("last content should be a tool use");
731 }
732 });
733 }
734 }
735
736 assert!(
737 saw_partial_tool_use,
738 "should see at least one partially streamed tool use in the history"
739 );
740}
741
742#[gpui::test]
743async fn test_tool_authorization(cx: &mut TestAppContext) {
744 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
745 let fake_model = model.as_fake();
746
747 let mut events = thread
748 .update(cx, |thread, cx| {
749 thread.add_tool(ToolRequiringPermission);
750 thread.send(UserMessageId::new(), ["abc"], cx)
751 })
752 .unwrap();
753 cx.run_until_parked();
754 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
755 LanguageModelToolUse {
756 id: "tool_id_1".into(),
757 name: ToolRequiringPermission::name().into(),
758 raw_input: "{}".into(),
759 input: json!({}),
760 is_input_complete: true,
761 thought_signature: None,
762 },
763 ));
764 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
765 LanguageModelToolUse {
766 id: "tool_id_2".into(),
767 name: ToolRequiringPermission::name().into(),
768 raw_input: "{}".into(),
769 input: json!({}),
770 is_input_complete: true,
771 thought_signature: None,
772 },
773 ));
774 fake_model.end_last_completion_stream();
775 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
776 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
777
778 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
779 tool_call_auth_1
780 .response
781 .send(acp::PermissionOptionId::new("allow"))
782 .unwrap();
783 cx.run_until_parked();
784
785 // Reject the second - send "deny" option_id directly since Deny is now a button
786 tool_call_auth_2
787 .response
788 .send(acp::PermissionOptionId::new("deny"))
789 .unwrap();
790 cx.run_until_parked();
791
792 let completion = fake_model.pending_completions().pop().unwrap();
793 let message = completion.messages.last().unwrap();
794 assert_eq!(
795 message.content,
796 vec![
797 language_model::MessageContent::ToolResult(LanguageModelToolResult {
798 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
799 tool_name: ToolRequiringPermission::name().into(),
800 is_error: false,
801 content: "Allowed".into(),
802 output: Some("Allowed".into())
803 }),
804 language_model::MessageContent::ToolResult(LanguageModelToolResult {
805 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
806 tool_name: ToolRequiringPermission::name().into(),
807 is_error: true,
808 content: "Permission to run tool denied by user".into(),
809 output: Some("Permission to run tool denied by user".into())
810 })
811 ]
812 );
813
814 // Simulate yet another tool call.
815 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
816 LanguageModelToolUse {
817 id: "tool_id_3".into(),
818 name: ToolRequiringPermission::name().into(),
819 raw_input: "{}".into(),
820 input: json!({}),
821 is_input_complete: true,
822 thought_signature: None,
823 },
824 ));
825 fake_model.end_last_completion_stream();
826
827 // Respond by always allowing tools - send transformed option_id
828 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
829 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
830 tool_call_auth_3
831 .response
832 .send(acp::PermissionOptionId::new(
833 "always_allow:tool_requiring_permission",
834 ))
835 .unwrap();
836 cx.run_until_parked();
837 let completion = fake_model.pending_completions().pop().unwrap();
838 let message = completion.messages.last().unwrap();
839 assert_eq!(
840 message.content,
841 vec![language_model::MessageContent::ToolResult(
842 LanguageModelToolResult {
843 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
844 tool_name: ToolRequiringPermission::name().into(),
845 is_error: false,
846 content: "Allowed".into(),
847 output: Some("Allowed".into())
848 }
849 )]
850 );
851
852 // Simulate a final tool call, ensuring we don't trigger authorization.
853 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
854 LanguageModelToolUse {
855 id: "tool_id_4".into(),
856 name: ToolRequiringPermission::name().into(),
857 raw_input: "{}".into(),
858 input: json!({}),
859 is_input_complete: true,
860 thought_signature: None,
861 },
862 ));
863 fake_model.end_last_completion_stream();
864 cx.run_until_parked();
865 let completion = fake_model.pending_completions().pop().unwrap();
866 let message = completion.messages.last().unwrap();
867 assert_eq!(
868 message.content,
869 vec![language_model::MessageContent::ToolResult(
870 LanguageModelToolResult {
871 tool_use_id: "tool_id_4".into(),
872 tool_name: ToolRequiringPermission::name().into(),
873 is_error: false,
874 content: "Allowed".into(),
875 output: Some("Allowed".into())
876 }
877 )]
878 );
879}
880
881#[gpui::test]
882async fn test_tool_hallucination(cx: &mut TestAppContext) {
883 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
884 let fake_model = model.as_fake();
885
886 let mut events = thread
887 .update(cx, |thread, cx| {
888 thread.send(UserMessageId::new(), ["abc"], cx)
889 })
890 .unwrap();
891 cx.run_until_parked();
892 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
893 LanguageModelToolUse {
894 id: "tool_id_1".into(),
895 name: "nonexistent_tool".into(),
896 raw_input: "{}".into(),
897 input: json!({}),
898 is_input_complete: true,
899 thought_signature: None,
900 },
901 ));
902 fake_model.end_last_completion_stream();
903
904 let tool_call = expect_tool_call(&mut events).await;
905 assert_eq!(tool_call.title, "nonexistent_tool");
906 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
907 let update = expect_tool_call_update_fields(&mut events).await;
908 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
909}
910
911async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
912 let event = events
913 .next()
914 .await
915 .expect("no tool call authorization event received")
916 .unwrap();
917 match event {
918 ThreadEvent::ToolCall(tool_call) => tool_call,
919 event => {
920 panic!("Unexpected event {event:?}");
921 }
922 }
923}
924
925async fn expect_tool_call_update_fields(
926 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
927) -> acp::ToolCallUpdate {
928 let event = events
929 .next()
930 .await
931 .expect("no tool call authorization event received")
932 .unwrap();
933 match event {
934 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
935 event => {
936 panic!("Unexpected event {event:?}");
937 }
938 }
939}
940
941async fn next_tool_call_authorization(
942 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
943) -> ToolCallAuthorization {
944 loop {
945 let event = events
946 .next()
947 .await
948 .expect("no tool call authorization event received")
949 .unwrap();
950 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
951 let permission_kinds = tool_call_authorization
952 .options
953 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
954 .map(|option| option.kind);
955 let allow_once = tool_call_authorization
956 .options
957 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
958 .map(|option| option.kind);
959
960 assert_eq!(
961 permission_kinds,
962 Some(acp::PermissionOptionKind::AllowAlways)
963 );
964 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
965 return tool_call_authorization;
966 }
967 }
968}
969
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // A command with a recognizable program offers tool-wide, per-command and one-off choices.
    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in [
        "Always for terminal",
        "Always for `cargo` commands",
        "Only this time",
    ] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
988
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new("edit_file", "src/main.rs");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // File edits offer a tool-wide choice plus one scoped to the file's directory.
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1005
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context = ToolPermissionContext::new("fetch", "https://docs.rs/gpui");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // Fetches offer a tool-wide choice plus one scoped to the URL's domain.
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1022
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new("terminal", "./deploy.sh --production");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // A script path yields no per-command choice: only tool-wide and one-off remain.
    assert_eq!(choices.len(), 2);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
    let has_command_choice = labels.iter().any(|label| label.contains("commands"));
    assert!(!has_command_choice);
}
1041
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // Split every choice into its allow id and its deny id.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();
    let has = |ids: &[String], wanted: &str| ids.iter().any(|candidate| candidate == wanted);

    assert!(has(&allow_ids, "always_allow:terminal"));
    assert!(has(&allow_ids, "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal:")),
        "Missing allow pattern option"
    );

    assert!(has(&deny_ids, "always_deny:terminal"));
    assert!(has(&deny_ids, "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal:")),
        "Missing deny pattern option"
    );
}
1078
1079#[gpui::test]
1080#[cfg_attr(not(feature = "e2e"), ignore)]
1081async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1082 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1083
1084 // Test concurrent tool calls with different delay times
1085 let events = thread
1086 .update(cx, |thread, cx| {
1087 thread.add_tool(DelayTool);
1088 thread.send(
1089 UserMessageId::new(),
1090 [
1091 "Call the delay tool twice in the same message.",
1092 "Once with 100ms. Once with 300ms.",
1093 "When both timers are complete, describe the outputs.",
1094 ],
1095 cx,
1096 )
1097 })
1098 .unwrap()
1099 .collect()
1100 .await;
1101
1102 let stop_reasons = stop_events(events);
1103 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1104
1105 thread.update(cx, |thread, _cx| {
1106 let last_message = thread.last_message().unwrap();
1107 let agent_message = last_message.as_agent_message().unwrap();
1108 let text = agent_message
1109 .content
1110 .iter()
1111 .filter_map(|content| {
1112 if let AgentMessageContent::Text(text) = content {
1113 Some(text.as_str())
1114 } else {
1115 None
1116 }
1117 })
1118 .collect::<String>();
1119
1120 assert!(text.contains("Ding"));
1121 });
1122}
1123
1124#[gpui::test]
1125async fn test_profiles(cx: &mut TestAppContext) {
1126 let ThreadTest {
1127 model, thread, fs, ..
1128 } = setup(cx, TestModel::Fake).await;
1129 let fake_model = model.as_fake();
1130
1131 thread.update(cx, |thread, _cx| {
1132 thread.add_tool(DelayTool);
1133 thread.add_tool(EchoTool);
1134 thread.add_tool(InfiniteTool);
1135 });
1136
1137 // Override profiles and wait for settings to be loaded.
1138 fs.insert_file(
1139 paths::settings_file(),
1140 json!({
1141 "agent": {
1142 "profiles": {
1143 "test-1": {
1144 "name": "Test Profile 1",
1145 "tools": {
1146 EchoTool::name(): true,
1147 DelayTool::name(): true,
1148 }
1149 },
1150 "test-2": {
1151 "name": "Test Profile 2",
1152 "tools": {
1153 InfiniteTool::name(): true,
1154 }
1155 }
1156 }
1157 }
1158 })
1159 .to_string()
1160 .into_bytes(),
1161 )
1162 .await;
1163 cx.run_until_parked();
1164
1165 // Test that test-1 profile (default) has echo and delay tools
1166 thread
1167 .update(cx, |thread, cx| {
1168 thread.set_profile(AgentProfileId("test-1".into()), cx);
1169 thread.send(UserMessageId::new(), ["test"], cx)
1170 })
1171 .unwrap();
1172 cx.run_until_parked();
1173
1174 let mut pending_completions = fake_model.pending_completions();
1175 assert_eq!(pending_completions.len(), 1);
1176 let completion = pending_completions.pop().unwrap();
1177 let tool_names: Vec<String> = completion
1178 .tools
1179 .iter()
1180 .map(|tool| tool.name.clone())
1181 .collect();
1182 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1183 fake_model.end_last_completion_stream();
1184
1185 // Switch to test-2 profile, and verify that it has only the infinite tool.
1186 thread
1187 .update(cx, |thread, cx| {
1188 thread.set_profile(AgentProfileId("test-2".into()), cx);
1189 thread.send(UserMessageId::new(), ["test2"], cx)
1190 })
1191 .unwrap();
1192 cx.run_until_parked();
1193 let mut pending_completions = fake_model.pending_completions();
1194 assert_eq!(pending_completions.len(), 1);
1195 let completion = pending_completions.pop().unwrap();
1196 let tool_names: Vec<String> = completion
1197 .tools
1198 .iter()
1199 .map(|tool| tool.name.clone())
1200 .collect();
1201 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1202}
1203
1204#[gpui::test]
1205async fn test_mcp_tools(cx: &mut TestAppContext) {
1206 let ThreadTest {
1207 model,
1208 thread,
1209 context_server_store,
1210 fs,
1211 ..
1212 } = setup(cx, TestModel::Fake).await;
1213 let fake_model = model.as_fake();
1214
1215 // Override profiles and wait for settings to be loaded.
1216 fs.insert_file(
1217 paths::settings_file(),
1218 json!({
1219 "agent": {
1220 "always_allow_tool_actions": true,
1221 "profiles": {
1222 "test": {
1223 "name": "Test Profile",
1224 "enable_all_context_servers": true,
1225 "tools": {
1226 EchoTool::name(): true,
1227 }
1228 },
1229 }
1230 }
1231 })
1232 .to_string()
1233 .into_bytes(),
1234 )
1235 .await;
1236 cx.run_until_parked();
1237 thread.update(cx, |thread, cx| {
1238 thread.set_profile(AgentProfileId("test".into()), cx)
1239 });
1240
1241 let mut mcp_tool_calls = setup_context_server(
1242 "test_server",
1243 vec![context_server::types::Tool {
1244 name: "echo".into(),
1245 description: None,
1246 input_schema: serde_json::to_value(EchoTool::input_schema(
1247 LanguageModelToolSchemaFormat::JsonSchema,
1248 ))
1249 .unwrap(),
1250 output_schema: None,
1251 annotations: None,
1252 }],
1253 &context_server_store,
1254 cx,
1255 );
1256
1257 let events = thread.update(cx, |thread, cx| {
1258 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1259 });
1260 cx.run_until_parked();
1261
1262 // Simulate the model calling the MCP tool.
1263 let completion = fake_model.pending_completions().pop().unwrap();
1264 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1265 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1266 LanguageModelToolUse {
1267 id: "tool_1".into(),
1268 name: "echo".into(),
1269 raw_input: json!({"text": "test"}).to_string(),
1270 input: json!({"text": "test"}),
1271 is_input_complete: true,
1272 thought_signature: None,
1273 },
1274 ));
1275 fake_model.end_last_completion_stream();
1276 cx.run_until_parked();
1277
1278 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1279 assert_eq!(tool_call_params.name, "echo");
1280 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1281 tool_call_response
1282 .send(context_server::types::CallToolResponse {
1283 content: vec![context_server::types::ToolResponseContent::Text {
1284 text: "test".into(),
1285 }],
1286 is_error: None,
1287 meta: None,
1288 structured_content: None,
1289 })
1290 .unwrap();
1291 cx.run_until_parked();
1292
1293 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1294 fake_model.send_last_completion_stream_text_chunk("Done!");
1295 fake_model.end_last_completion_stream();
1296 events.collect::<Vec<_>>().await;
1297
1298 // Send again after adding the echo tool, ensuring the name collision is resolved.
1299 let events = thread.update(cx, |thread, cx| {
1300 thread.add_tool(EchoTool);
1301 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1302 });
1303 cx.run_until_parked();
1304 let completion = fake_model.pending_completions().pop().unwrap();
1305 assert_eq!(
1306 tool_names_for_completion(&completion),
1307 vec!["echo", "test_server_echo"]
1308 );
1309 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1310 LanguageModelToolUse {
1311 id: "tool_2".into(),
1312 name: "test_server_echo".into(),
1313 raw_input: json!({"text": "mcp"}).to_string(),
1314 input: json!({"text": "mcp"}),
1315 is_input_complete: true,
1316 thought_signature: None,
1317 },
1318 ));
1319 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1320 LanguageModelToolUse {
1321 id: "tool_3".into(),
1322 name: "echo".into(),
1323 raw_input: json!({"text": "native"}).to_string(),
1324 input: json!({"text": "native"}),
1325 is_input_complete: true,
1326 thought_signature: None,
1327 },
1328 ));
1329 fake_model.end_last_completion_stream();
1330 cx.run_until_parked();
1331
1332 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1333 assert_eq!(tool_call_params.name, "echo");
1334 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1335 tool_call_response
1336 .send(context_server::types::CallToolResponse {
1337 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1338 is_error: None,
1339 meta: None,
1340 structured_content: None,
1341 })
1342 .unwrap();
1343 cx.run_until_parked();
1344
1345 // Ensure the tool results were inserted with the correct names.
1346 let completion = fake_model.pending_completions().pop().unwrap();
1347 assert_eq!(
1348 completion.messages.last().unwrap().content,
1349 vec![
1350 MessageContent::ToolResult(LanguageModelToolResult {
1351 tool_use_id: "tool_3".into(),
1352 tool_name: "echo".into(),
1353 is_error: false,
1354 content: "native".into(),
1355 output: Some("native".into()),
1356 },),
1357 MessageContent::ToolResult(LanguageModelToolResult {
1358 tool_use_id: "tool_2".into(),
1359 tool_name: "test_server_echo".into(),
1360 is_error: false,
1361 content: "mcp".into(),
1362 output: Some("mcp".into()),
1363 },),
1364 ]
1365 );
1366 fake_model.end_last_completion_stream();
1367 events.collect::<Vec<_>>().await;
1368}
1369
1370#[gpui::test]
1371async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1372 let ThreadTest {
1373 model,
1374 thread,
1375 context_server_store,
1376 fs,
1377 ..
1378 } = setup(cx, TestModel::Fake).await;
1379 let fake_model = model.as_fake();
1380
1381 // Set up a profile with all tools enabled
1382 fs.insert_file(
1383 paths::settings_file(),
1384 json!({
1385 "agent": {
1386 "profiles": {
1387 "test": {
1388 "name": "Test Profile",
1389 "enable_all_context_servers": true,
1390 "tools": {
1391 EchoTool::name(): true,
1392 DelayTool::name(): true,
1393 WordListTool::name(): true,
1394 ToolRequiringPermission::name(): true,
1395 InfiniteTool::name(): true,
1396 }
1397 },
1398 }
1399 }
1400 })
1401 .to_string()
1402 .into_bytes(),
1403 )
1404 .await;
1405 cx.run_until_parked();
1406
1407 thread.update(cx, |thread, cx| {
1408 thread.set_profile(AgentProfileId("test".into()), cx);
1409 thread.add_tool(EchoTool);
1410 thread.add_tool(DelayTool);
1411 thread.add_tool(WordListTool);
1412 thread.add_tool(ToolRequiringPermission);
1413 thread.add_tool(InfiniteTool);
1414 });
1415
1416 // Set up multiple context servers with some overlapping tool names
1417 let _server1_calls = setup_context_server(
1418 "xxx",
1419 vec![
1420 context_server::types::Tool {
1421 name: "echo".into(), // Conflicts with native EchoTool
1422 description: None,
1423 input_schema: serde_json::to_value(EchoTool::input_schema(
1424 LanguageModelToolSchemaFormat::JsonSchema,
1425 ))
1426 .unwrap(),
1427 output_schema: None,
1428 annotations: None,
1429 },
1430 context_server::types::Tool {
1431 name: "unique_tool_1".into(),
1432 description: None,
1433 input_schema: json!({"type": "object", "properties": {}}),
1434 output_schema: None,
1435 annotations: None,
1436 },
1437 ],
1438 &context_server_store,
1439 cx,
1440 );
1441
1442 let _server2_calls = setup_context_server(
1443 "yyy",
1444 vec![
1445 context_server::types::Tool {
1446 name: "echo".into(), // Also conflicts with native EchoTool
1447 description: None,
1448 input_schema: serde_json::to_value(EchoTool::input_schema(
1449 LanguageModelToolSchemaFormat::JsonSchema,
1450 ))
1451 .unwrap(),
1452 output_schema: None,
1453 annotations: None,
1454 },
1455 context_server::types::Tool {
1456 name: "unique_tool_2".into(),
1457 description: None,
1458 input_schema: json!({"type": "object", "properties": {}}),
1459 output_schema: None,
1460 annotations: None,
1461 },
1462 context_server::types::Tool {
1463 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1464 description: None,
1465 input_schema: json!({"type": "object", "properties": {}}),
1466 output_schema: None,
1467 annotations: None,
1468 },
1469 context_server::types::Tool {
1470 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1471 description: None,
1472 input_schema: json!({"type": "object", "properties": {}}),
1473 output_schema: None,
1474 annotations: None,
1475 },
1476 ],
1477 &context_server_store,
1478 cx,
1479 );
1480 let _server3_calls = setup_context_server(
1481 "zzz",
1482 vec![
1483 context_server::types::Tool {
1484 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1485 description: None,
1486 input_schema: json!({"type": "object", "properties": {}}),
1487 output_schema: None,
1488 annotations: None,
1489 },
1490 context_server::types::Tool {
1491 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1492 description: None,
1493 input_schema: json!({"type": "object", "properties": {}}),
1494 output_schema: None,
1495 annotations: None,
1496 },
1497 context_server::types::Tool {
1498 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1499 description: None,
1500 input_schema: json!({"type": "object", "properties": {}}),
1501 output_schema: None,
1502 annotations: None,
1503 },
1504 ],
1505 &context_server_store,
1506 cx,
1507 );
1508
1509 thread
1510 .update(cx, |thread, cx| {
1511 thread.send(UserMessageId::new(), ["Go"], cx)
1512 })
1513 .unwrap();
1514 cx.run_until_parked();
1515 let completion = fake_model.pending_completions().pop().unwrap();
1516 assert_eq!(
1517 tool_names_for_completion(&completion),
1518 vec![
1519 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1520 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1521 "delay",
1522 "echo",
1523 "infinite",
1524 "tool_requiring_permission",
1525 "unique_tool_1",
1526 "unique_tool_2",
1527 "word_list",
1528 "xxx_echo",
1529 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1530 "yyy_echo",
1531 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1532 ]
1533 );
1534}
1535
1536#[gpui::test]
1537#[cfg_attr(not(feature = "e2e"), ignore)]
1538async fn test_cancellation(cx: &mut TestAppContext) {
1539 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1540
1541 let mut events = thread
1542 .update(cx, |thread, cx| {
1543 thread.add_tool(InfiniteTool);
1544 thread.add_tool(EchoTool);
1545 thread.send(
1546 UserMessageId::new(),
1547 ["Call the echo tool, then call the infinite tool, then explain their output"],
1548 cx,
1549 )
1550 })
1551 .unwrap();
1552
1553 // Wait until both tools are called.
1554 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1555 let mut echo_id = None;
1556 let mut echo_completed = false;
1557 while let Some(event) = events.next().await {
1558 match event.unwrap() {
1559 ThreadEvent::ToolCall(tool_call) => {
1560 assert_eq!(tool_call.title, expected_tools.remove(0));
1561 if tool_call.title == "Echo" {
1562 echo_id = Some(tool_call.tool_call_id);
1563 }
1564 }
1565 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1566 acp::ToolCallUpdate {
1567 tool_call_id,
1568 fields:
1569 acp::ToolCallUpdateFields {
1570 status: Some(acp::ToolCallStatus::Completed),
1571 ..
1572 },
1573 ..
1574 },
1575 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1576 echo_completed = true;
1577 }
1578 _ => {}
1579 }
1580
1581 if expected_tools.is_empty() && echo_completed {
1582 break;
1583 }
1584 }
1585
1586 // Cancel the current send and ensure that the event stream is closed, even
1587 // if one of the tools is still running.
1588 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1589 let events = events.collect::<Vec<_>>().await;
1590 let last_event = events.last();
1591 assert!(
1592 matches!(
1593 last_event,
1594 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1595 ),
1596 "unexpected event {last_event:?}"
1597 );
1598
1599 // Ensure we can still send a new message after cancellation.
1600 let events = thread
1601 .update(cx, |thread, cx| {
1602 thread.send(
1603 UserMessageId::new(),
1604 ["Testing: reply with 'Hello' then stop."],
1605 cx,
1606 )
1607 })
1608 .unwrap()
1609 .collect::<Vec<_>>()
1610 .await;
1611 thread.update(cx, |thread, _cx| {
1612 let message = thread.last_message().unwrap();
1613 let agent_message = message.as_agent_message().unwrap();
1614 assert_eq!(
1615 agent_message.content,
1616 vec![AgentMessageContent::Text("Hello".to_string())]
1617 );
1618 });
1619 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1620}
1621
1622#[gpui::test]
1623async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1624 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1625 always_allow_tools(cx);
1626 let fake_model = model.as_fake();
1627
1628 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1629 let environment = Rc::new(FakeThreadEnvironment {
1630 handle: handle.clone(),
1631 });
1632
1633 let mut events = thread
1634 .update(cx, |thread, cx| {
1635 thread.add_tool(crate::TerminalTool::new(
1636 thread.project().clone(),
1637 environment,
1638 ));
1639 thread.send(UserMessageId::new(), ["run a command"], cx)
1640 })
1641 .unwrap();
1642
1643 cx.run_until_parked();
1644
1645 // Simulate the model calling the terminal tool
1646 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1647 LanguageModelToolUse {
1648 id: "terminal_tool_1".into(),
1649 name: "terminal".into(),
1650 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1651 input: json!({"command": "sleep 1000", "cd": "."}),
1652 is_input_complete: true,
1653 thought_signature: None,
1654 },
1655 ));
1656 fake_model.end_last_completion_stream();
1657
1658 // Wait for the terminal tool to start running
1659 wait_for_terminal_tool_started(&mut events, cx).await;
1660
1661 // Cancel the thread while the terminal is running
1662 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1663
1664 // Collect remaining events, driving the executor to let cancellation complete
1665 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1666
1667 // Verify the terminal was killed
1668 assert!(
1669 handle.was_killed(),
1670 "expected terminal handle to be killed on cancellation"
1671 );
1672
1673 // Verify we got a cancellation stop event
1674 assert_eq!(
1675 stop_events(remaining_events),
1676 vec![acp::StopReason::Cancelled],
1677 );
1678
1679 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1680 thread.update(cx, |thread, _cx| {
1681 let message = thread.last_message().unwrap();
1682 let agent_message = message.as_agent_message().unwrap();
1683
1684 let tool_use = agent_message
1685 .content
1686 .iter()
1687 .find_map(|content| match content {
1688 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1689 _ => None,
1690 })
1691 .expect("expected tool use in agent message");
1692
1693 let tool_result = agent_message
1694 .tool_results
1695 .get(&tool_use.id)
1696 .expect("expected tool result");
1697
1698 let result_text = match &tool_result.content {
1699 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1700 _ => panic!("expected text content in tool result"),
1701 };
1702
1703 // "partial output" comes from FakeTerminalHandle's output field
1704 assert!(
1705 result_text.contains("partial output"),
1706 "expected tool result to contain terminal output, got: {result_text}"
1707 );
1708 // Match the actual format from process_content in terminal_tool.rs
1709 assert!(
1710 result_text.contains("The user stopped this command"),
1711 "expected tool result to indicate user stopped, got: {result_text}"
1712 );
1713 });
1714
1715 // Verify we can send a new message after cancellation
1716 verify_thread_recovery(&thread, &fake_model, cx).await;
1717}
1718
1719#[gpui::test]
1720async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1721 // This test verifies that tools which properly handle cancellation via
1722 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1723 // to cancellation and report that they were cancelled.
1724 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1725 always_allow_tools(cx);
1726 let fake_model = model.as_fake();
1727
1728 let (tool, was_cancelled) = CancellationAwareTool::new();
1729
1730 let mut events = thread
1731 .update(cx, |thread, cx| {
1732 thread.add_tool(tool);
1733 thread.send(
1734 UserMessageId::new(),
1735 ["call the cancellation aware tool"],
1736 cx,
1737 )
1738 })
1739 .unwrap();
1740
1741 cx.run_until_parked();
1742
1743 // Simulate the model calling the cancellation-aware tool
1744 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1745 LanguageModelToolUse {
1746 id: "cancellation_aware_1".into(),
1747 name: "cancellation_aware".into(),
1748 raw_input: r#"{}"#.into(),
1749 input: json!({}),
1750 is_input_complete: true,
1751 thought_signature: None,
1752 },
1753 ));
1754 fake_model.end_last_completion_stream();
1755
1756 cx.run_until_parked();
1757
1758 // Wait for the tool call to be reported
1759 let mut tool_started = false;
1760 let deadline = cx.executor().num_cpus() * 100;
1761 for _ in 0..deadline {
1762 cx.run_until_parked();
1763
1764 while let Some(Some(event)) = events.next().now_or_never() {
1765 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1766 if tool_call.title == "Cancellation Aware Tool" {
1767 tool_started = true;
1768 break;
1769 }
1770 }
1771 }
1772
1773 if tool_started {
1774 break;
1775 }
1776
1777 cx.background_executor
1778 .timer(Duration::from_millis(10))
1779 .await;
1780 }
1781 assert!(tool_started, "expected cancellation aware tool to start");
1782
1783 // Cancel the thread and wait for it to complete
1784 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1785
1786 // The cancel task should complete promptly because the tool handles cancellation
1787 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1788 futures::select! {
1789 _ = cancel_task.fuse() => {}
1790 _ = timeout.fuse() => {
1791 panic!("cancel task timed out - tool did not respond to cancellation");
1792 }
1793 }
1794
1795 // Verify the tool detected cancellation via its flag
1796 assert!(
1797 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1798 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1799 );
1800
1801 // Collect remaining events
1802 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1803
1804 // Verify we got a cancellation stop event
1805 assert_eq!(
1806 stop_events(remaining_events),
1807 vec![acp::StopReason::Cancelled],
1808 );
1809
1810 // Verify we can send a new message after cancellation
1811 verify_thread_recovery(&thread, &fake_model, cx).await;
1812}
1813
1814/// Helper to verify thread can recover after cancellation by sending a simple message.
1815async fn verify_thread_recovery(
1816 thread: &Entity<Thread>,
1817 fake_model: &FakeLanguageModel,
1818 cx: &mut TestAppContext,
1819) {
1820 let events = thread
1821 .update(cx, |thread, cx| {
1822 thread.send(
1823 UserMessageId::new(),
1824 ["Testing: reply with 'Hello' then stop."],
1825 cx,
1826 )
1827 })
1828 .unwrap();
1829 cx.run_until_parked();
1830 fake_model.send_last_completion_stream_text_chunk("Hello");
1831 fake_model
1832 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1833 fake_model.end_last_completion_stream();
1834
1835 let events = events.collect::<Vec<_>>().await;
1836 thread.update(cx, |thread, _cx| {
1837 let message = thread.last_message().unwrap();
1838 let agent_message = message.as_agent_message().unwrap();
1839 assert_eq!(
1840 agent_message.content,
1841 vec![AgentMessageContent::Text("Hello".to_string())]
1842 );
1843 });
1844 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1845}
1846
1847/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1848async fn wait_for_terminal_tool_started(
1849 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1850 cx: &mut TestAppContext,
1851) {
1852 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1853 for _ in 0..deadline {
1854 cx.run_until_parked();
1855
1856 while let Some(Some(event)) = events.next().now_or_never() {
1857 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1858 update,
1859 ))) = &event
1860 {
1861 if update.fields.content.as_ref().is_some_and(|content| {
1862 content
1863 .iter()
1864 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1865 }) {
1866 return;
1867 }
1868 }
1869 }
1870
1871 cx.background_executor
1872 .timer(Duration::from_millis(10))
1873 .await;
1874 }
1875 panic!("terminal tool did not start within the expected time");
1876}
1877
1878/// Collects events until a Stop event is received, driving the executor to completion.
1879async fn collect_events_until_stop(
1880 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1881 cx: &mut TestAppContext,
1882) -> Vec<Result<ThreadEvent>> {
1883 let mut collected = Vec::new();
1884 let deadline = cx.executor().num_cpus() * 200;
1885
1886 for _ in 0..deadline {
1887 cx.executor().advance_clock(Duration::from_millis(10));
1888 cx.run_until_parked();
1889
1890 while let Some(Some(event)) = events.next().now_or_never() {
1891 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1892 collected.push(event);
1893 if is_stop {
1894 return collected;
1895 }
1896 }
1897 }
1898 panic!(
1899 "did not receive Stop event within the expected time; collected {} events",
1900 collected.len()
1901 );
1902}
1903
1904#[gpui::test]
1905async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1906 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1907 always_allow_tools(cx);
1908 let fake_model = model.as_fake();
1909
1910 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1911 let environment = Rc::new(FakeThreadEnvironment {
1912 handle: handle.clone(),
1913 });
1914
1915 let message_id = UserMessageId::new();
1916 let mut events = thread
1917 .update(cx, |thread, cx| {
1918 thread.add_tool(crate::TerminalTool::new(
1919 thread.project().clone(),
1920 environment,
1921 ));
1922 thread.send(message_id.clone(), ["run a command"], cx)
1923 })
1924 .unwrap();
1925
1926 cx.run_until_parked();
1927
1928 // Simulate the model calling the terminal tool
1929 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1930 LanguageModelToolUse {
1931 id: "terminal_tool_1".into(),
1932 name: "terminal".into(),
1933 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1934 input: json!({"command": "sleep 1000", "cd": "."}),
1935 is_input_complete: true,
1936 thought_signature: None,
1937 },
1938 ));
1939 fake_model.end_last_completion_stream();
1940
1941 // Wait for the terminal tool to start running
1942 wait_for_terminal_tool_started(&mut events, cx).await;
1943
1944 // Truncate the thread while the terminal is running
1945 thread
1946 .update(cx, |thread, cx| thread.truncate(message_id, cx))
1947 .unwrap();
1948
1949 // Drive the executor to let cancellation complete
1950 let _ = collect_events_until_stop(&mut events, cx).await;
1951
1952 // Verify the terminal was killed
1953 assert!(
1954 handle.was_killed(),
1955 "expected terminal handle to be killed on truncate"
1956 );
1957
1958 // Verify the thread is empty after truncation
1959 thread.update(cx, |thread, _cx| {
1960 assert_eq!(
1961 thread.to_markdown(),
1962 "",
1963 "expected thread to be empty after truncating the only message"
1964 );
1965 });
1966
1967 // Verify we can send a new message after truncation
1968 verify_thread_recovery(&thread, &fake_model, cx).await;
1969}
1970
1971#[gpui::test]
1972async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1973 // Tests that cancellation properly kills all running terminal tools when multiple are active.
1974 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1975 always_allow_tools(cx);
1976 let fake_model = model.as_fake();
1977
1978 let environment = Rc::new(MultiTerminalEnvironment::new());
1979
1980 let mut events = thread
1981 .update(cx, |thread, cx| {
1982 thread.add_tool(crate::TerminalTool::new(
1983 thread.project().clone(),
1984 environment.clone(),
1985 ));
1986 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1987 })
1988 .unwrap();
1989
1990 cx.run_until_parked();
1991
1992 // Simulate the model calling two terminal tools
1993 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1994 LanguageModelToolUse {
1995 id: "terminal_tool_1".into(),
1996 name: "terminal".into(),
1997 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1998 input: json!({"command": "sleep 1000", "cd": "."}),
1999 is_input_complete: true,
2000 thought_signature: None,
2001 },
2002 ));
2003 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2004 LanguageModelToolUse {
2005 id: "terminal_tool_2".into(),
2006 name: "terminal".into(),
2007 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2008 input: json!({"command": "sleep 2000", "cd": "."}),
2009 is_input_complete: true,
2010 thought_signature: None,
2011 },
2012 ));
2013 fake_model.end_last_completion_stream();
2014
2015 // Wait for both terminal tools to start by counting terminal content updates
2016 let mut terminals_started = 0;
2017 let deadline = cx.executor().num_cpus() * 100;
2018 for _ in 0..deadline {
2019 cx.run_until_parked();
2020
2021 while let Some(Some(event)) = events.next().now_or_never() {
2022 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2023 update,
2024 ))) = &event
2025 {
2026 if update.fields.content.as_ref().is_some_and(|content| {
2027 content
2028 .iter()
2029 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2030 }) {
2031 terminals_started += 1;
2032 if terminals_started >= 2 {
2033 break;
2034 }
2035 }
2036 }
2037 }
2038 if terminals_started >= 2 {
2039 break;
2040 }
2041
2042 cx.background_executor
2043 .timer(Duration::from_millis(10))
2044 .await;
2045 }
2046 assert!(
2047 terminals_started >= 2,
2048 "expected 2 terminal tools to start, got {terminals_started}"
2049 );
2050
2051 // Cancel the thread while both terminals are running
2052 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2053
2054 // Collect remaining events
2055 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2056
2057 // Verify both terminal handles were killed
2058 let handles = environment.handles();
2059 assert_eq!(
2060 handles.len(),
2061 2,
2062 "expected 2 terminal handles to be created"
2063 );
2064 assert!(
2065 handles[0].was_killed(),
2066 "expected first terminal handle to be killed on cancellation"
2067 );
2068 assert!(
2069 handles[1].was_killed(),
2070 "expected second terminal handle to be killed on cancellation"
2071 );
2072
2073 // Verify we got a cancellation stop event
2074 assert_eq!(
2075 stop_events(remaining_events),
2076 vec![acp::StopReason::Cancelled],
2077 );
2078}
2079
/// Stopping a running terminal command from the terminal card (rather than
/// cancelling the whole thread) must complete the tool call with a
/// "user stopped" result, after which the turn ends normally with `EndTurn`.
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A `new_never_exits` handle: the command is only ended by this test, via
    // `signal_exit` below. A clone is kept so the handle can still be driven
    // after `environment` is moved into the tool.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the terminal tool use so its result can be looked up by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2174
/// When the model passes `timeout_ms` and the command outlives it, the terminal
/// must be killed and the tool result must report a timeout (not a user stop).
/// The turn still ends normally with `EndTurn`.
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A `new_never_exits` handle: nothing in this test signals exit, so only the
    // configured timeout can end the command.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (200ms is beyond the 100ms `timeout_ms` above)
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Locate the terminal tool use so its result can be looked up by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2273
/// Sending a second message while the first is still streaming cancels the
/// first: its event stream stops with `Cancelled`, while the second stream
/// completes with `EndTurn`.
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Deliberately leave the first completion open (no Stop, no end of stream)
    // so it is still in progress when the second message is sent.
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Presumably "last completion" now refers to the request issued for the
    // second send (verify against FakeLanguageModel).
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2304
/// A send that starts only after the previous one has fully finished must not
/// retroactively cancel it: both event streams end with `EndTurn`.
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the first stream completely before the second send starts.
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2337
/// A `Refusal` stop from the model ends the turn with `StopReason::Refusal`
/// and rolls back the refused exchange. Both the user message and the partial
/// assistant reply are removed.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the refused exchange is rolled back:
    // the user message and the partial assistant reply are both removed, which
    // leaves this single-exchange thread empty.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2386
2387#[gpui::test]
2388async fn test_truncate_first_message(cx: &mut TestAppContext) {
2389 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2390 let fake_model = model.as_fake();
2391
2392 let message_id = UserMessageId::new();
2393 thread
2394 .update(cx, |thread, cx| {
2395 thread.send(message_id.clone(), ["Hello"], cx)
2396 })
2397 .unwrap();
2398 cx.run_until_parked();
2399 thread.read_with(cx, |thread, _| {
2400 assert_eq!(
2401 thread.to_markdown(),
2402 indoc! {"
2403 ## User
2404
2405 Hello
2406 "}
2407 );
2408 assert_eq!(thread.latest_token_usage(), None);
2409 });
2410
2411 fake_model.send_last_completion_stream_text_chunk("Hey!");
2412 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2413 language_model::TokenUsage {
2414 input_tokens: 32_000,
2415 output_tokens: 16_000,
2416 cache_creation_input_tokens: 0,
2417 cache_read_input_tokens: 0,
2418 },
2419 ));
2420 cx.run_until_parked();
2421 thread.read_with(cx, |thread, _| {
2422 assert_eq!(
2423 thread.to_markdown(),
2424 indoc! {"
2425 ## User
2426
2427 Hello
2428
2429 ## Assistant
2430
2431 Hey!
2432 "}
2433 );
2434 assert_eq!(
2435 thread.latest_token_usage(),
2436 Some(acp_thread::TokenUsage {
2437 used_tokens: 32_000 + 16_000,
2438 max_tokens: 1_000_000,
2439 input_tokens: 32_000,
2440 output_tokens: 16_000,
2441 })
2442 );
2443 });
2444
2445 thread
2446 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2447 .unwrap();
2448 cx.run_until_parked();
2449 thread.read_with(cx, |thread, _| {
2450 assert_eq!(thread.to_markdown(), "");
2451 assert_eq!(thread.latest_token_usage(), None);
2452 });
2453
2454 // Ensure we can still send a new message after truncation.
2455 thread
2456 .update(cx, |thread, cx| {
2457 thread.send(UserMessageId::new(), ["Hi"], cx)
2458 })
2459 .unwrap();
2460 thread.update(cx, |thread, _cx| {
2461 assert_eq!(
2462 thread.to_markdown(),
2463 indoc! {"
2464 ## User
2465
2466 Hi
2467 "}
2468 );
2469 });
2470 cx.run_until_parked();
2471 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2472 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2473 language_model::TokenUsage {
2474 input_tokens: 40_000,
2475 output_tokens: 20_000,
2476 cache_creation_input_tokens: 0,
2477 cache_read_input_tokens: 0,
2478 },
2479 ));
2480 cx.run_until_parked();
2481 thread.read_with(cx, |thread, _| {
2482 assert_eq!(
2483 thread.to_markdown(),
2484 indoc! {"
2485 ## User
2486
2487 Hi
2488
2489 ## Assistant
2490
2491 Ahoy!
2492 "}
2493 );
2494
2495 assert_eq!(
2496 thread.latest_token_usage(),
2497 Some(acp_thread::TokenUsage {
2498 used_tokens: 40_000 + 20_000,
2499 max_tokens: 1_000_000,
2500 input_tokens: 40_000,
2501 output_tokens: 20_000,
2502 })
2503 );
2504 });
2505}
2506
2507#[gpui::test]
2508async fn test_truncate_second_message(cx: &mut TestAppContext) {
2509 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2510 let fake_model = model.as_fake();
2511
2512 thread
2513 .update(cx, |thread, cx| {
2514 thread.send(UserMessageId::new(), ["Message 1"], cx)
2515 })
2516 .unwrap();
2517 cx.run_until_parked();
2518 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2520 language_model::TokenUsage {
2521 input_tokens: 32_000,
2522 output_tokens: 16_000,
2523 cache_creation_input_tokens: 0,
2524 cache_read_input_tokens: 0,
2525 },
2526 ));
2527 fake_model.end_last_completion_stream();
2528 cx.run_until_parked();
2529
2530 let assert_first_message_state = |cx: &mut TestAppContext| {
2531 thread.clone().read_with(cx, |thread, _| {
2532 assert_eq!(
2533 thread.to_markdown(),
2534 indoc! {"
2535 ## User
2536
2537 Message 1
2538
2539 ## Assistant
2540
2541 Message 1 response
2542 "}
2543 );
2544
2545 assert_eq!(
2546 thread.latest_token_usage(),
2547 Some(acp_thread::TokenUsage {
2548 used_tokens: 32_000 + 16_000,
2549 max_tokens: 1_000_000,
2550 input_tokens: 32_000,
2551 output_tokens: 16_000,
2552 })
2553 );
2554 });
2555 };
2556
2557 assert_first_message_state(cx);
2558
2559 let second_message_id = UserMessageId::new();
2560 thread
2561 .update(cx, |thread, cx| {
2562 thread.send(second_message_id.clone(), ["Message 2"], cx)
2563 })
2564 .unwrap();
2565 cx.run_until_parked();
2566
2567 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2568 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2569 language_model::TokenUsage {
2570 input_tokens: 40_000,
2571 output_tokens: 20_000,
2572 cache_creation_input_tokens: 0,
2573 cache_read_input_tokens: 0,
2574 },
2575 ));
2576 fake_model.end_last_completion_stream();
2577 cx.run_until_parked();
2578
2579 thread.read_with(cx, |thread, _| {
2580 assert_eq!(
2581 thread.to_markdown(),
2582 indoc! {"
2583 ## User
2584
2585 Message 1
2586
2587 ## Assistant
2588
2589 Message 1 response
2590
2591 ## User
2592
2593 Message 2
2594
2595 ## Assistant
2596
2597 Message 2 response
2598 "}
2599 );
2600
2601 assert_eq!(
2602 thread.latest_token_usage(),
2603 Some(acp_thread::TokenUsage {
2604 used_tokens: 40_000 + 20_000,
2605 max_tokens: 1_000_000,
2606 input_tokens: 40_000,
2607 output_tokens: 20_000,
2608 })
2609 );
2610 });
2611
2612 thread
2613 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2614 .unwrap();
2615 cx.run_until_parked();
2616
2617 assert_first_message_state(cx);
2618}
2619
/// The first completed exchange triggers title generation through the
/// summarization model. Later exchanges must not request another title.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // A separate fake model so the summary request can be driven independently
    // of the main completion.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The title keeps its default until the summary model has responded.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks split words and contain a newline; the expected title is only
    // the first line, "Hello world".
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2665
/// A request built mid-turn must leave out tool calls that are still pending
/// but include finished ones together with their results. Here the pending
/// call is the `ToolRequiringPermission` use, which the test never approves.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // Message 0 is skipped (presumably the system prompt; TODO confirm). Only
    // the echo tool use and its result appear after the user message.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2745
2746#[gpui::test]
2747async fn test_agent_connection(cx: &mut TestAppContext) {
2748 cx.update(settings::init);
2749 let templates = Templates::new();
2750
2751 // Initialize language model system with test provider
2752 cx.update(|cx| {
2753 gpui_tokio::init(cx);
2754
2755 let http_client = FakeHttpClient::with_404_response();
2756 let clock = Arc::new(clock::FakeSystemClock::new());
2757 let client = Client::new(clock, http_client, cx);
2758 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2759 language_model::init(client.clone(), cx);
2760 language_models::init(user_store, client.clone(), cx);
2761 LanguageModelRegistry::test(cx);
2762 });
2763 cx.executor().forbid_parking();
2764
2765 // Create a project for new_thread
2766 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2767 fake_fs.insert_tree(path!("/test"), json!({})).await;
2768 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2769 let cwd = Path::new("/test");
2770 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2771
2772 // Create agent and connection
2773 let agent = NativeAgent::new(
2774 project.clone(),
2775 thread_store,
2776 templates.clone(),
2777 None,
2778 fake_fs.clone(),
2779 &mut cx.to_async(),
2780 )
2781 .await
2782 .unwrap();
2783 let connection = NativeAgentConnection(agent.clone());
2784
2785 // Create a thread using new_thread
2786 let connection_rc = Rc::new(connection.clone());
2787 let acp_thread = cx
2788 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2789 .await
2790 .expect("new_thread should succeed");
2791
2792 // Get the session_id from the AcpThread
2793 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2794
2795 // Test model_selector returns Some
2796 let selector_opt = connection.model_selector(&session_id);
2797 assert!(
2798 selector_opt.is_some(),
2799 "agent should always support ModelSelector"
2800 );
2801 let selector = selector_opt.unwrap();
2802
2803 // Test list_models
2804 let listed_models = cx
2805 .update(|cx| selector.list_models(cx))
2806 .await
2807 .expect("list_models should succeed");
2808 let AgentModelList::Grouped(listed_models) = listed_models else {
2809 panic!("Unexpected model list type");
2810 };
2811 assert!(!listed_models.is_empty(), "should have at least one model");
2812 assert_eq!(
2813 listed_models[&AgentModelGroupName("Fake".into())][0]
2814 .id
2815 .0
2816 .as_ref(),
2817 "fake/fake"
2818 );
2819
2820 // Test selected_model returns the default
2821 let model = cx
2822 .update(|cx| selector.selected_model(cx))
2823 .await
2824 .expect("selected_model should succeed");
2825 let model = cx
2826 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2827 .unwrap();
2828 let model = model.as_fake();
2829 assert_eq!(model.id().0, "fake", "should return default model");
2830
2831 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2832 cx.run_until_parked();
2833 model.send_last_completion_stream_text_chunk("def");
2834 cx.run_until_parked();
2835 acp_thread.read_with(cx, |thread, cx| {
2836 assert_eq!(
2837 thread.to_markdown(cx),
2838 indoc! {"
2839 ## User
2840
2841 abc
2842
2843 ## Assistant
2844
2845 def
2846
2847 "}
2848 )
2849 });
2850
2851 // Test cancel
2852 cx.update(|cx| connection.cancel(&session_id, cx));
2853 request.await.expect("prompt should fail gracefully");
2854
2855 // Ensure that dropping the ACP thread causes the native thread to be
2856 // dropped as well.
2857 cx.update(|_| drop(acp_thread));
2858 let result = cx
2859 .update(|cx| {
2860 connection.prompt(
2861 Some(acp_thread::UserMessageId::new()),
2862 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2863 cx,
2864 )
2865 })
2866 .await;
2867 assert_eq!(
2868 result.as_ref().unwrap_err().to_string(),
2869 "Session not found",
2870 "unexpected result: {:?}",
2871 result
2872 );
2873}
2874
2875#[gpui::test]
2876async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2877 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2878 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2879 let fake_model = model.as_fake();
2880
2881 let mut events = thread
2882 .update(cx, |thread, cx| {
2883 thread.send(UserMessageId::new(), ["Think"], cx)
2884 })
2885 .unwrap();
2886 cx.run_until_parked();
2887
2888 // Simulate streaming partial input.
2889 let input = json!({});
2890 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2891 LanguageModelToolUse {
2892 id: "1".into(),
2893 name: ThinkingTool::name().into(),
2894 raw_input: input.to_string(),
2895 input,
2896 is_input_complete: false,
2897 thought_signature: None,
2898 },
2899 ));
2900
2901 // Input streaming completed
2902 let input = json!({ "content": "Thinking hard!" });
2903 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2904 LanguageModelToolUse {
2905 id: "1".into(),
2906 name: "thinking".into(),
2907 raw_input: input.to_string(),
2908 input,
2909 is_input_complete: true,
2910 thought_signature: None,
2911 },
2912 ));
2913 fake_model.end_last_completion_stream();
2914 cx.run_until_parked();
2915
2916 let tool_call = expect_tool_call(&mut events).await;
2917 assert_eq!(
2918 tool_call,
2919 acp::ToolCall::new("1", "Thinking")
2920 .kind(acp::ToolKind::Think)
2921 .raw_input(json!({}))
2922 .meta(acp::Meta::from_iter([(
2923 "tool_name".into(),
2924 "thinking".into()
2925 )]))
2926 );
2927 let update = expect_tool_call_update_fields(&mut events).await;
2928 assert_eq!(
2929 update,
2930 acp::ToolCallUpdate::new(
2931 "1",
2932 acp::ToolCallUpdateFields::new()
2933 .title("Thinking")
2934 .kind(acp::ToolKind::Think)
2935 .raw_input(json!({ "content": "Thinking hard!"}))
2936 )
2937 );
2938 let update = expect_tool_call_update_fields(&mut events).await;
2939 assert_eq!(
2940 update,
2941 acp::ToolCallUpdate::new(
2942 "1",
2943 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2944 )
2945 );
2946 let update = expect_tool_call_update_fields(&mut events).await;
2947 assert_eq!(
2948 update,
2949 acp::ToolCallUpdate::new(
2950 "1",
2951 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2952 )
2953 );
2954 let update = expect_tool_call_update_fields(&mut events).await;
2955 assert_eq!(
2956 update,
2957 acp::ToolCallUpdate::new(
2958 "1",
2959 acp::ToolCallUpdateFields::new()
2960 .status(acp::ToolCallStatus::Completed)
2961 .raw_output("Finished thinking.")
2962 )
2963 );
2964}
2965
2966#[gpui::test]
2967async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2968 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2969 let fake_model = model.as_fake();
2970
2971 let mut events = thread
2972 .update(cx, |thread, cx| {
2973 thread.send(UserMessageId::new(), ["Hello!"], cx)
2974 })
2975 .unwrap();
2976 cx.run_until_parked();
2977
2978 fake_model.send_last_completion_stream_text_chunk("Hey!");
2979 fake_model.end_last_completion_stream();
2980
2981 let mut retry_events = Vec::new();
2982 while let Some(Ok(event)) = events.next().await {
2983 match event {
2984 ThreadEvent::Retry(retry_status) => {
2985 retry_events.push(retry_status);
2986 }
2987 ThreadEvent::Stop(..) => break,
2988 _ => {}
2989 }
2990 }
2991
2992 assert_eq!(retry_events.len(), 0);
2993 thread.read_with(cx, |thread, _cx| {
2994 assert_eq!(
2995 thread.to_markdown(),
2996 indoc! {"
2997 ## User
2998
2999 Hello!
3000
3001 ## Assistant
3002
3003 Hey!
3004 "}
3005 )
3006 });
3007}
3008
3009#[gpui::test]
3010async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3011 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3012 let fake_model = model.as_fake();
3013
3014 let mut events = thread
3015 .update(cx, |thread, cx| {
3016 thread.send(UserMessageId::new(), ["Hello!"], cx)
3017 })
3018 .unwrap();
3019 cx.run_until_parked();
3020
3021 fake_model.send_last_completion_stream_text_chunk("Hey,");
3022 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3023 provider: LanguageModelProviderName::new("Anthropic"),
3024 retry_after: Some(Duration::from_secs(3)),
3025 });
3026 fake_model.end_last_completion_stream();
3027
3028 cx.executor().advance_clock(Duration::from_secs(3));
3029 cx.run_until_parked();
3030
3031 fake_model.send_last_completion_stream_text_chunk("there!");
3032 fake_model.end_last_completion_stream();
3033 cx.run_until_parked();
3034
3035 let mut retry_events = Vec::new();
3036 while let Some(Ok(event)) = events.next().await {
3037 match event {
3038 ThreadEvent::Retry(retry_status) => {
3039 retry_events.push(retry_status);
3040 }
3041 ThreadEvent::Stop(..) => break,
3042 _ => {}
3043 }
3044 }
3045
3046 assert_eq!(retry_events.len(), 1);
3047 assert!(matches!(
3048 retry_events[0],
3049 acp_thread::RetryStatus { attempt: 1, .. }
3050 ));
3051 thread.read_with(cx, |thread, _cx| {
3052 assert_eq!(
3053 thread.to_markdown(),
3054 indoc! {"
3055 ## User
3056
3057 Hello!
3058
3059 ## Assistant
3060
3061 Hey,
3062
3063 [resume]
3064
3065 ## Assistant
3066
3067 there!
3068 "}
3069 )
3070 });
3071}
3072
/// If a completion fails after the model has emitted a tool use, the retry
/// request must carry that tool use together with its result. The turn then
/// completes normally with the retried reply.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit a complete tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After `retry_after` elapses, the most recent pending completion is the
    // retry request; it must contain the tool use and the echo tool's result.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // Message 0 is skipped (presumably the system prompt; TODO confirm). Only
    // the final message (the tool result) is expected to have `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    // The last message holds only the retried reply, with no tool results.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3152
3153#[gpui::test]
3154async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3155 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3156 let fake_model = model.as_fake();
3157
3158 let mut events = thread
3159 .update(cx, |thread, cx| {
3160 thread.send(UserMessageId::new(), ["Hello!"], cx)
3161 })
3162 .unwrap();
3163 cx.run_until_parked();
3164
3165 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3166 fake_model.send_last_completion_stream_error(
3167 LanguageModelCompletionError::ServerOverloaded {
3168 provider: LanguageModelProviderName::new("Anthropic"),
3169 retry_after: Some(Duration::from_secs(3)),
3170 },
3171 );
3172 fake_model.end_last_completion_stream();
3173 cx.executor().advance_clock(Duration::from_secs(3));
3174 cx.run_until_parked();
3175 }
3176
3177 let mut errors = Vec::new();
3178 let mut retry_events = Vec::new();
3179 while let Some(event) = events.next().await {
3180 match event {
3181 Ok(ThreadEvent::Retry(retry_status)) => {
3182 retry_events.push(retry_status);
3183 }
3184 Ok(ThreadEvent::Stop(..)) => break,
3185 Err(error) => errors.push(error),
3186 _ => {}
3187 }
3188 }
3189
3190 assert_eq!(
3191 retry_events.len(),
3192 crate::thread::MAX_RETRY_ATTEMPTS as usize
3193 );
3194 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3195 assert_eq!(retry_events[i].attempt, i + 1);
3196 }
3197 assert_eq!(errors.len(), 1);
3198 let error = errors[0]
3199 .downcast_ref::<LanguageModelCompletionError>()
3200 .unwrap();
3201 assert!(matches!(
3202 error,
3203 LanguageModelCompletionError::ServerOverloaded { .. }
3204 ));
3205}
3206
3207/// Filters out the stop events for asserting against in tests
3208fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3209 result_events
3210 .into_iter()
3211 .filter_map(|event| match event.unwrap() {
3212 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3213 _ => None,
3214 })
3215 .collect()
3216}
3217
3218struct ThreadTest {
3219 model: Arc<dyn LanguageModel>,
3220 thread: Entity<Thread>,
3221 project_context: Entity<ProjectContext>,
3222 context_server_store: Entity<ContextServerStore>,
3223 fs: Arc<FakeFs>,
3224}
3225
/// Selects which language model [`setup`] attaches to the thread under test.
enum TestModel {
    /// An in-process `FakeLanguageModel` whose completion streams the test drives by hand.
    Fake,
    /// The real Claude Sonnet 4 model, looked up in the registry via the production client.
    Sonnet4,
}
3230
3231impl TestModel {
3232 fn id(&self) -> LanguageModelId {
3233 match self {
3234 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3235 TestModel::Fake => unreachable!(),
3236 }
3237 }
3238}
3239
/// Builds a [`ThreadTest`] fixture.
///
/// It creates a fake filesystem whose user settings file selects a `test-profile`
/// agent profile enabling the test tools plus `terminal`. It also creates a test
/// project rooted at `/test` and a `Thread` bound to the requested model.
///
/// For [`TestModel::Sonnet4`] it additionally initializes the real HTTP client and
/// language-model providers and waits for the provider to authenticate (the
/// `unwrap`s panic on failure). NOTE(review): that path presumably needs network
/// access and credentials.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): lets the test executor park while waiting on work (e.g. real I/O).
    cx.executor().allow_parking();

    // Write a user settings file enabling every test tool under `test-profile`.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real model needs the HTTP client and provider registry.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Feed the settings file (and later edits to it) into the SettingsStore.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: a fresh fake, or the registry entry once its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3343
/// Runs once at test-binary load time (via `ctor`) and installs `env_logger`
/// only when `RUST_LOG` is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3351
3352fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3353 let fs = fs.clone();
3354 cx.spawn({
3355 async move |cx| {
3356 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3357 cx.background_executor(),
3358 fs,
3359 paths::settings_file().clone(),
3360 );
3361 let _watcher_task = watcher_task;
3362
3363 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3364 cx.update(|cx| {
3365 SettingsStore::update_global(cx, |settings, cx| {
3366 settings.set_user_settings(&new_settings_content, cx)
3367 })
3368 })
3369 .ok();
3370 }
3371 }
3372 })
3373 .detach();
3374}
3375
3376fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3377 completion
3378 .tools
3379 .iter()
3380 .map(|tool| tool.name.clone())
3381 .collect()
3382}
3383
/// Registers and starts a fake context (MCP) server named `name`.
///
/// It adds a stdio entry for `name` to the global `ProjectSettings`, then starts a
/// server in `context_server_store` over an in-memory fake transport whose
/// `ListTools` response is `tools`. It parks until the store settles.
///
/// Returns a receiver yielding one `(params, responder)` pair per `CallTool` request
/// the server receives. The test completes each call by sending a
/// `CallToolResponse` through `responder`.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in settings. The command is a placeholder; NOTE(review):
    // presumably never spawned, since the server is started with a fake transport below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support with list-change notifications.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Cloned per request so the handler can answer more than once.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each call to the test and wait for its reply. This panics if the
        // returned receiver was dropped, or if the responder is dropped unanswered.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3463
3464#[gpui::test]
3465async fn test_tokens_before_message(cx: &mut TestAppContext) {
3466 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3467 let fake_model = model.as_fake();
3468
3469 // First message
3470 let message_1_id = UserMessageId::new();
3471 thread
3472 .update(cx, |thread, cx| {
3473 thread.send(message_1_id.clone(), ["First message"], cx)
3474 })
3475 .unwrap();
3476 cx.run_until_parked();
3477
3478 // Before any response, tokens_before_message should return None for first message
3479 thread.read_with(cx, |thread, _| {
3480 assert_eq!(
3481 thread.tokens_before_message(&message_1_id),
3482 None,
3483 "First message should have no tokens before it"
3484 );
3485 });
3486
3487 // Complete first message with usage
3488 fake_model.send_last_completion_stream_text_chunk("Response 1");
3489 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3490 language_model::TokenUsage {
3491 input_tokens: 100,
3492 output_tokens: 50,
3493 cache_creation_input_tokens: 0,
3494 cache_read_input_tokens: 0,
3495 },
3496 ));
3497 fake_model.end_last_completion_stream();
3498 cx.run_until_parked();
3499
3500 // First message still has no tokens before it
3501 thread.read_with(cx, |thread, _| {
3502 assert_eq!(
3503 thread.tokens_before_message(&message_1_id),
3504 None,
3505 "First message should still have no tokens before it after response"
3506 );
3507 });
3508
3509 // Second message
3510 let message_2_id = UserMessageId::new();
3511 thread
3512 .update(cx, |thread, cx| {
3513 thread.send(message_2_id.clone(), ["Second message"], cx)
3514 })
3515 .unwrap();
3516 cx.run_until_parked();
3517
3518 // Second message should have first message's input tokens before it
3519 thread.read_with(cx, |thread, _| {
3520 assert_eq!(
3521 thread.tokens_before_message(&message_2_id),
3522 Some(100),
3523 "Second message should have 100 tokens before it (from first request)"
3524 );
3525 });
3526
3527 // Complete second message
3528 fake_model.send_last_completion_stream_text_chunk("Response 2");
3529 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3530 language_model::TokenUsage {
3531 input_tokens: 250, // Total for this request (includes previous context)
3532 output_tokens: 75,
3533 cache_creation_input_tokens: 0,
3534 cache_read_input_tokens: 0,
3535 },
3536 ));
3537 fake_model.end_last_completion_stream();
3538 cx.run_until_parked();
3539
3540 // Third message
3541 let message_3_id = UserMessageId::new();
3542 thread
3543 .update(cx, |thread, cx| {
3544 thread.send(message_3_id.clone(), ["Third message"], cx)
3545 })
3546 .unwrap();
3547 cx.run_until_parked();
3548
3549 // Third message should have second message's input tokens (250) before it
3550 thread.read_with(cx, |thread, _| {
3551 assert_eq!(
3552 thread.tokens_before_message(&message_3_id),
3553 Some(250),
3554 "Third message should have 250 tokens before it (from second request)"
3555 );
3556 // Second message should still have 100
3557 assert_eq!(
3558 thread.tokens_before_message(&message_2_id),
3559 Some(100),
3560 "Second message should still have 100 tokens before it"
3561 );
3562 // First message still has none
3563 assert_eq!(
3564 thread.tokens_before_message(&message_1_id),
3565 None,
3566 "First message should still have no tokens before it"
3567 );
3568 });
3569}
3570
3571#[gpui::test]
3572async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3573 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3574 let fake_model = model.as_fake();
3575
3576 // Set up three messages with responses
3577 let message_1_id = UserMessageId::new();
3578 thread
3579 .update(cx, |thread, cx| {
3580 thread.send(message_1_id.clone(), ["Message 1"], cx)
3581 })
3582 .unwrap();
3583 cx.run_until_parked();
3584 fake_model.send_last_completion_stream_text_chunk("Response 1");
3585 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3586 language_model::TokenUsage {
3587 input_tokens: 100,
3588 output_tokens: 50,
3589 cache_creation_input_tokens: 0,
3590 cache_read_input_tokens: 0,
3591 },
3592 ));
3593 fake_model.end_last_completion_stream();
3594 cx.run_until_parked();
3595
3596 let message_2_id = UserMessageId::new();
3597 thread
3598 .update(cx, |thread, cx| {
3599 thread.send(message_2_id.clone(), ["Message 2"], cx)
3600 })
3601 .unwrap();
3602 cx.run_until_parked();
3603 fake_model.send_last_completion_stream_text_chunk("Response 2");
3604 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3605 language_model::TokenUsage {
3606 input_tokens: 250,
3607 output_tokens: 75,
3608 cache_creation_input_tokens: 0,
3609 cache_read_input_tokens: 0,
3610 },
3611 ));
3612 fake_model.end_last_completion_stream();
3613 cx.run_until_parked();
3614
3615 // Verify initial state
3616 thread.read_with(cx, |thread, _| {
3617 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3618 });
3619
3620 // Truncate at message 2 (removes message 2 and everything after)
3621 thread
3622 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3623 .unwrap();
3624 cx.run_until_parked();
3625
3626 // After truncation, message_2_id no longer exists, so lookup should return None
3627 thread.read_with(cx, |thread, _| {
3628 assert_eq!(
3629 thread.tokens_before_message(&message_2_id),
3630 None,
3631 "After truncation, message 2 no longer exists"
3632 );
3633 // Message 1 still exists but has no tokens before it
3634 assert_eq!(
3635 thread.tokens_before_message(&message_1_id),
3636 None,
3637 "First message still has no tokens before it"
3638 );
3639 });
3640}
3641
3642#[gpui::test]
3643async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3644 init_test(cx);
3645
3646 let fs = FakeFs::new(cx.executor());
3647 fs.insert_tree("/root", json!({})).await;
3648 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3649
3650 // Test 1: Deny rule blocks command
3651 {
3652 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3653 let environment = Rc::new(FakeThreadEnvironment {
3654 handle: handle.clone(),
3655 });
3656
3657 cx.update(|cx| {
3658 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3659 settings.tool_permissions.tools.insert(
3660 "terminal".into(),
3661 agent_settings::ToolRules {
3662 default_mode: settings::ToolPermissionMode::Confirm,
3663 always_allow: vec![],
3664 always_deny: vec![
3665 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3666 ],
3667 always_confirm: vec![],
3668 invalid_patterns: vec![],
3669 },
3670 );
3671 agent_settings::AgentSettings::override_global(settings, cx);
3672 });
3673
3674 #[allow(clippy::arc_with_non_send_sync)]
3675 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3676 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3677
3678 let task = cx.update(|cx| {
3679 tool.run(
3680 crate::TerminalToolInput {
3681 command: "rm -rf /".to_string(),
3682 cd: ".".to_string(),
3683 timeout_ms: None,
3684 },
3685 event_stream,
3686 cx,
3687 )
3688 });
3689
3690 let result = task.await;
3691 assert!(
3692 result.is_err(),
3693 "expected command to be blocked by deny rule"
3694 );
3695 assert!(
3696 result.unwrap_err().to_string().contains("blocked"),
3697 "error should mention the command was blocked"
3698 );
3699 }
3700
3701 // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3702 {
3703 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3704 let environment = Rc::new(FakeThreadEnvironment {
3705 handle: handle.clone(),
3706 });
3707
3708 cx.update(|cx| {
3709 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3710 settings.always_allow_tool_actions = false;
3711 settings.tool_permissions.tools.insert(
3712 "terminal".into(),
3713 agent_settings::ToolRules {
3714 default_mode: settings::ToolPermissionMode::Deny,
3715 always_allow: vec![
3716 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3717 ],
3718 always_deny: vec![],
3719 always_confirm: vec![],
3720 invalid_patterns: vec![],
3721 },
3722 );
3723 agent_settings::AgentSettings::override_global(settings, cx);
3724 });
3725
3726 #[allow(clippy::arc_with_non_send_sync)]
3727 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3728 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3729
3730 let task = cx.update(|cx| {
3731 tool.run(
3732 crate::TerminalToolInput {
3733 command: "echo hello".to_string(),
3734 cd: ".".to_string(),
3735 timeout_ms: None,
3736 },
3737 event_stream,
3738 cx,
3739 )
3740 });
3741
3742 let update = rx.expect_update_fields().await;
3743 assert!(
3744 update.content.iter().any(|blocks| {
3745 blocks
3746 .iter()
3747 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3748 }),
3749 "expected terminal content (allow rule should skip confirmation and override default deny)"
3750 );
3751
3752 let result = task.await;
3753 assert!(
3754 result.is_ok(),
3755 "expected command to succeed without confirmation"
3756 );
3757 }
3758
3759 // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
3760 {
3761 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3762 let environment = Rc::new(FakeThreadEnvironment {
3763 handle: handle.clone(),
3764 });
3765
3766 cx.update(|cx| {
3767 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3768 settings.always_allow_tool_actions = true;
3769 settings.tool_permissions.tools.insert(
3770 "terminal".into(),
3771 agent_settings::ToolRules {
3772 default_mode: settings::ToolPermissionMode::Allow,
3773 always_allow: vec![],
3774 always_deny: vec![],
3775 always_confirm: vec![
3776 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3777 ],
3778 invalid_patterns: vec![],
3779 },
3780 );
3781 agent_settings::AgentSettings::override_global(settings, cx);
3782 });
3783
3784 #[allow(clippy::arc_with_non_send_sync)]
3785 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3786 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3787
3788 let task = cx.update(|cx| {
3789 tool.run(
3790 crate::TerminalToolInput {
3791 command: "sudo rm file".to_string(),
3792 cd: ".".to_string(),
3793 timeout_ms: None,
3794 },
3795 event_stream,
3796 cx,
3797 )
3798 });
3799
3800 // With always_allow_tool_actions=true, confirm patterns are overridden
3801 task.await
3802 .expect("command should be allowed with always_allow_tool_actions=true");
3803 }
3804
3805 // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
3806 {
3807 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3808 let environment = Rc::new(FakeThreadEnvironment {
3809 handle: handle.clone(),
3810 });
3811
3812 cx.update(|cx| {
3813 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3814 settings.always_allow_tool_actions = true;
3815 settings.tool_permissions.tools.insert(
3816 "terminal".into(),
3817 agent_settings::ToolRules {
3818 default_mode: settings::ToolPermissionMode::Deny,
3819 always_allow: vec![],
3820 always_deny: vec![],
3821 always_confirm: vec![],
3822 invalid_patterns: vec![],
3823 },
3824 );
3825 agent_settings::AgentSettings::override_global(settings, cx);
3826 });
3827
3828 #[allow(clippy::arc_with_non_send_sync)]
3829 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3830 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3831
3832 let task = cx.update(|cx| {
3833 tool.run(
3834 crate::TerminalToolInput {
3835 command: "echo hello".to_string(),
3836 cd: ".".to_string(),
3837 timeout_ms: None,
3838 },
3839 event_stream,
3840 cx,
3841 )
3842 });
3843
3844 // With always_allow_tool_actions=true, even default_mode: Deny is overridden
3845 task.await
3846 .expect("command should be allowed with always_allow_tool_actions=true");
3847 }
3848}
3849
3850#[gpui::test]
3851async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3852 init_test(cx);
3853
3854 cx.update(|cx| {
3855 cx.update_flags(true, vec!["subagents".to_string()]);
3856 });
3857
3858 let fs = FakeFs::new(cx.executor());
3859 fs.insert_tree(path!("/test"), json!({})).await;
3860 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3861 let project_context = cx.new(|_cx| ProjectContext::default());
3862 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3863 let context_server_registry =
3864 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3865 let model = Arc::new(FakeLanguageModel::default());
3866
3867 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3868 let environment = Rc::new(FakeThreadEnvironment { handle });
3869
3870 let thread = cx.new(|cx| {
3871 let mut thread = Thread::new(
3872 project.clone(),
3873 project_context,
3874 context_server_registry,
3875 Templates::new(),
3876 Some(model),
3877 cx,
3878 );
3879 thread.add_default_tools(environment, cx);
3880 thread
3881 });
3882
3883 thread.read_with(cx, |thread, _| {
3884 assert!(
3885 thread.has_registered_tool("subagent"),
3886 "subagent tool should be present when feature flag is enabled"
3887 );
3888 });
3889}
3890
3891#[gpui::test]
3892async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3893 init_test(cx);
3894
3895 cx.update(|cx| {
3896 cx.update_flags(true, vec!["subagents".to_string()]);
3897 });
3898
3899 let fs = FakeFs::new(cx.executor());
3900 fs.insert_tree(path!("/test"), json!({})).await;
3901 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3902 let project_context = cx.new(|_cx| ProjectContext::default());
3903 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3904 let context_server_registry =
3905 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3906 let model = Arc::new(FakeLanguageModel::default());
3907
3908 let subagent_context = SubagentContext {
3909 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3910 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3911 depth: 1,
3912 summary_prompt: "Summarize".to_string(),
3913 context_low_prompt: "Context low".to_string(),
3914 };
3915
3916 let subagent = cx.new(|cx| {
3917 Thread::new_subagent(
3918 project.clone(),
3919 project_context,
3920 context_server_registry,
3921 Templates::new(),
3922 model.clone(),
3923 subagent_context,
3924 std::collections::BTreeMap::new(),
3925 cx,
3926 )
3927 });
3928
3929 subagent.read_with(cx, |thread, _| {
3930 assert!(thread.is_subagent());
3931 assert_eq!(thread.depth(), 1);
3932 assert!(thread.model().is_some());
3933 });
3934}
3935
3936#[gpui::test]
3937async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
3938 init_test(cx);
3939
3940 cx.update(|cx| {
3941 cx.update_flags(true, vec!["subagents".to_string()]);
3942 });
3943
3944 let fs = FakeFs::new(cx.executor());
3945 fs.insert_tree(path!("/test"), json!({})).await;
3946 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3947 let project_context = cx.new(|_cx| ProjectContext::default());
3948 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3949 let context_server_registry =
3950 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3951 let model = Arc::new(FakeLanguageModel::default());
3952
3953 let subagent_context = SubagentContext {
3954 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3955 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3956 depth: MAX_SUBAGENT_DEPTH,
3957 summary_prompt: "Summarize".to_string(),
3958 context_low_prompt: "Context low".to_string(),
3959 };
3960
3961 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3962 let environment = Rc::new(FakeThreadEnvironment { handle });
3963
3964 let deep_subagent = cx.new(|cx| {
3965 let mut thread = Thread::new_subagent(
3966 project.clone(),
3967 project_context,
3968 context_server_registry,
3969 Templates::new(),
3970 model.clone(),
3971 subagent_context,
3972 std::collections::BTreeMap::new(),
3973 cx,
3974 );
3975 thread.add_default_tools(environment, cx);
3976 thread
3977 });
3978
3979 deep_subagent.read_with(cx, |thread, _| {
3980 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3981 assert!(
3982 !thread.has_registered_tool("subagent"),
3983 "subagent tool should not be present at max depth"
3984 );
3985 });
3986}
3987
3988#[gpui::test]
3989async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3990 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3991 let fake_model = model.as_fake();
3992
3993 cx.update(|cx| {
3994 cx.update_flags(true, vec!["subagents".to_string()]);
3995 });
3996
3997 let subagent_context = SubagentContext {
3998 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3999 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4000 depth: 1,
4001 summary_prompt: "Summarize your work".to_string(),
4002 context_low_prompt: "Context low, wrap up".to_string(),
4003 };
4004
4005 let project = thread.read_with(cx, |t, _| t.project.clone());
4006 let project_context = cx.new(|_cx| ProjectContext::default());
4007 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4008 let context_server_registry =
4009 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4010
4011 let subagent = cx.new(|cx| {
4012 Thread::new_subagent(
4013 project.clone(),
4014 project_context,
4015 context_server_registry,
4016 Templates::new(),
4017 model.clone(),
4018 subagent_context,
4019 std::collections::BTreeMap::new(),
4020 cx,
4021 )
4022 });
4023
4024 let task_prompt = "Find all TODO comments in the codebase";
4025 subagent
4026 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4027 .unwrap();
4028 cx.run_until_parked();
4029
4030 let pending = fake_model.pending_completions();
4031 assert_eq!(pending.len(), 1, "should have one pending completion");
4032
4033 let messages = &pending[0].messages;
4034 let user_messages: Vec<_> = messages
4035 .iter()
4036 .filter(|m| m.role == language_model::Role::User)
4037 .collect();
4038 assert_eq!(user_messages.len(), 1, "should have one user message");
4039
4040 let content = &user_messages[0].content[0];
4041 assert!(
4042 content.to_str().unwrap().contains("TODO"),
4043 "task prompt should be in user message"
4044 );
4045}
4046
4047#[gpui::test]
4048async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4049 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4050 let fake_model = model.as_fake();
4051
4052 cx.update(|cx| {
4053 cx.update_flags(true, vec!["subagents".to_string()]);
4054 });
4055
4056 let subagent_context = SubagentContext {
4057 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4058 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4059 depth: 1,
4060 summary_prompt: "Please summarize what you found".to_string(),
4061 context_low_prompt: "Context low, wrap up".to_string(),
4062 };
4063
4064 let project = thread.read_with(cx, |t, _| t.project.clone());
4065 let project_context = cx.new(|_cx| ProjectContext::default());
4066 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4067 let context_server_registry =
4068 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4069
4070 let subagent = cx.new(|cx| {
4071 Thread::new_subagent(
4072 project.clone(),
4073 project_context,
4074 context_server_registry,
4075 Templates::new(),
4076 model.clone(),
4077 subagent_context,
4078 std::collections::BTreeMap::new(),
4079 cx,
4080 )
4081 });
4082
4083 subagent
4084 .update(cx, |thread, cx| {
4085 thread.submit_user_message("Do some work", cx)
4086 })
4087 .unwrap();
4088 cx.run_until_parked();
4089
4090 fake_model.send_last_completion_stream_text_chunk("I did the work");
4091 fake_model
4092 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4093 fake_model.end_last_completion_stream();
4094 cx.run_until_parked();
4095
4096 subagent
4097 .update(cx, |thread, cx| thread.request_final_summary(cx))
4098 .unwrap();
4099 cx.run_until_parked();
4100
4101 let pending = fake_model.pending_completions();
4102 assert!(
4103 !pending.is_empty(),
4104 "should have pending completion for summary"
4105 );
4106
4107 let messages = &pending.last().unwrap().messages;
4108 let user_messages: Vec<_> = messages
4109 .iter()
4110 .filter(|m| m.role == language_model::Role::User)
4111 .collect();
4112
4113 let last_user = user_messages.last().unwrap();
4114 assert!(
4115 last_user.content[0].to_str().unwrap().contains("summarize"),
4116 "summary prompt should be sent"
4117 );
4118}
4119
4120#[gpui::test]
4121async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4122 init_test(cx);
4123
4124 cx.update(|cx| {
4125 cx.update_flags(true, vec!["subagents".to_string()]);
4126 });
4127
4128 let fs = FakeFs::new(cx.executor());
4129 fs.insert_tree(path!("/test"), json!({})).await;
4130 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4131 let project_context = cx.new(|_cx| ProjectContext::default());
4132 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4133 let context_server_registry =
4134 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4135 let model = Arc::new(FakeLanguageModel::default());
4136
4137 let subagent_context = SubagentContext {
4138 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4139 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4140 depth: 1,
4141 summary_prompt: "Summarize".to_string(),
4142 context_low_prompt: "Context low".to_string(),
4143 };
4144
4145 let subagent = cx.new(|cx| {
4146 let mut thread = Thread::new_subagent(
4147 project.clone(),
4148 project_context,
4149 context_server_registry,
4150 Templates::new(),
4151 model.clone(),
4152 subagent_context,
4153 std::collections::BTreeMap::new(),
4154 cx,
4155 );
4156 thread.add_tool(EchoTool);
4157 thread.add_tool(DelayTool);
4158 thread.add_tool(WordListTool);
4159 thread
4160 });
4161
4162 subagent.read_with(cx, |thread, _| {
4163 assert!(thread.has_registered_tool("echo"));
4164 assert!(thread.has_registered_tool("delay"));
4165 assert!(thread.has_registered_tool("word_list"));
4166 });
4167
4168 let allowed: collections::HashSet<gpui::SharedString> =
4169 vec!["echo".into()].into_iter().collect();
4170
4171 subagent.update(cx, |thread, _cx| {
4172 thread.restrict_tools(&allowed);
4173 });
4174
4175 subagent.read_with(cx, |thread, _| {
4176 assert!(
4177 thread.has_registered_tool("echo"),
4178 "echo should still be available"
4179 );
4180 assert!(
4181 !thread.has_registered_tool("delay"),
4182 "delay should be removed"
4183 );
4184 assert!(
4185 !thread.has_registered_tool("word_list"),
4186 "word_list should be removed"
4187 );
4188 });
4189}
4190
4191#[gpui::test]
4192async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4193 init_test(cx);
4194
4195 cx.update(|cx| {
4196 cx.update_flags(true, vec!["subagents".to_string()]);
4197 });
4198
4199 let fs = FakeFs::new(cx.executor());
4200 fs.insert_tree(path!("/test"), json!({})).await;
4201 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4202 let project_context = cx.new(|_cx| ProjectContext::default());
4203 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4204 let context_server_registry =
4205 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4206 let model = Arc::new(FakeLanguageModel::default());
4207
4208 let parent = cx.new(|cx| {
4209 Thread::new(
4210 project.clone(),
4211 project_context.clone(),
4212 context_server_registry.clone(),
4213 Templates::new(),
4214 Some(model.clone()),
4215 cx,
4216 )
4217 });
4218
4219 let subagent_context = SubagentContext {
4220 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4221 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4222 depth: 1,
4223 summary_prompt: "Summarize".to_string(),
4224 context_low_prompt: "Context low".to_string(),
4225 };
4226
4227 let subagent = cx.new(|cx| {
4228 Thread::new_subagent(
4229 project.clone(),
4230 project_context.clone(),
4231 context_server_registry.clone(),
4232 Templates::new(),
4233 model.clone(),
4234 subagent_context,
4235 std::collections::BTreeMap::new(),
4236 cx,
4237 )
4238 });
4239
4240 parent.update(cx, |thread, _cx| {
4241 thread.register_running_subagent(subagent.downgrade());
4242 });
4243
4244 subagent
4245 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4246 .unwrap();
4247 cx.run_until_parked();
4248
4249 subagent.read_with(cx, |thread, _| {
4250 assert!(!thread.is_turn_complete(), "subagent should be running");
4251 });
4252
4253 parent.update(cx, |thread, cx| {
4254 thread.cancel(cx).detach();
4255 });
4256
4257 subagent.read_with(cx, |thread, _| {
4258 assert!(
4259 thread.is_turn_complete(),
4260 "subagent should be cancelled when parent cancels"
4261 );
4262 });
4263}
4264
/// Runs `SubagentTool` with a cancellable event stream, signals cancellation after
/// the tool has started, and expects the tool's task to fail with an error whose
/// message contains "cancelled by user" before a 5-second executor timer fires.
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    // NOTE(review): only the error returned by the tool task is asserted below;
    // the state of any spawned subagent is not inspected.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // The lint allow is needed because SubagentTool is not Send + Sync; the Arc is
    // only used inside this test.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));

    // `cancellation_tx` is the test-side handle used to trigger cancellation later.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    // A clone of the stream is passed in, so `event_stream` itself stays alive in
    // this scope for the rest of the test.
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                label: "Long running task".to_string(),
                task_prompt: "Do a very long task that takes forever".to_string(),
                summary_prompt: "Summarize".to_string(),
                context_low_prompt: "Context low".to_string(),
                timeout_ms: None,
                allowed_tools: None,
            },
            event_stream.clone(),
            cx,
        )
    });

    // Let the tool make progress before cancelling; presumably this gets the
    // subagent started so cancellation interrupts real work.
    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    // Race the task against a timer so a missed cancellation fails the test with a
    // clear panic instead of hanging.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4340
4341#[gpui::test]
4342async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4343 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4344 let fake_model = model.as_fake();
4345
4346 cx.update(|cx| {
4347 cx.update_flags(true, vec!["subagents".to_string()]);
4348 });
4349
4350 let subagent_context = SubagentContext {
4351 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4352 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4353 depth: 1,
4354 summary_prompt: "Summarize".to_string(),
4355 context_low_prompt: "Context low".to_string(),
4356 };
4357
4358 let project = thread.read_with(cx, |t, _| t.project.clone());
4359 let project_context = cx.new(|_cx| ProjectContext::default());
4360 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4361 let context_server_registry =
4362 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4363
4364 let subagent = cx.new(|cx| {
4365 Thread::new_subagent(
4366 project.clone(),
4367 project_context,
4368 context_server_registry,
4369 Templates::new(),
4370 model.clone(),
4371 subagent_context,
4372 std::collections::BTreeMap::new(),
4373 cx,
4374 )
4375 });
4376
4377 subagent
4378 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4379 .unwrap();
4380 cx.run_until_parked();
4381
4382 subagent.read_with(cx, |thread, _| {
4383 assert!(!thread.is_turn_complete(), "turn should be in progress");
4384 });
4385
4386 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4387 provider: LanguageModelProviderName::from("Fake".to_string()),
4388 });
4389 fake_model.end_last_completion_stream();
4390 cx.run_until_parked();
4391
4392 subagent.read_with(cx, |thread, _| {
4393 assert!(
4394 thread.is_turn_complete(),
4395 "turn should be complete after non-retryable error"
4396 );
4397 });
4398}
4399
4400#[gpui::test]
4401async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4402 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4403 let fake_model = model.as_fake();
4404
4405 cx.update(|cx| {
4406 cx.update_flags(true, vec!["subagents".to_string()]);
4407 });
4408
4409 let subagent_context = SubagentContext {
4410 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4411 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4412 depth: 1,
4413 summary_prompt: "Summarize your work".to_string(),
4414 context_low_prompt: "Context low, stop and summarize".to_string(),
4415 };
4416
4417 let project = thread.read_with(cx, |t, _| t.project.clone());
4418 let project_context = cx.new(|_cx| ProjectContext::default());
4419 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4420 let context_server_registry =
4421 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4422
4423 let subagent = cx.new(|cx| {
4424 Thread::new_subagent(
4425 project.clone(),
4426 project_context.clone(),
4427 context_server_registry.clone(),
4428 Templates::new(),
4429 model.clone(),
4430 subagent_context.clone(),
4431 std::collections::BTreeMap::new(),
4432 cx,
4433 )
4434 });
4435
4436 subagent.update(cx, |thread, _| {
4437 thread.add_tool(EchoTool);
4438 });
4439
4440 subagent
4441 .update(cx, |thread, cx| {
4442 thread.submit_user_message("Do some work", cx)
4443 })
4444 .unwrap();
4445 cx.run_until_parked();
4446
4447 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4448 fake_model
4449 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4450 fake_model.end_last_completion_stream();
4451 cx.run_until_parked();
4452
4453 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4454 assert!(
4455 interrupt_result.is_ok(),
4456 "interrupt_for_summary should succeed"
4457 );
4458
4459 cx.run_until_parked();
4460
4461 let pending = fake_model.pending_completions();
4462 assert!(
4463 !pending.is_empty(),
4464 "should have pending completion for interrupted summary"
4465 );
4466
4467 let messages = &pending.last().unwrap().messages;
4468 let user_messages: Vec<_> = messages
4469 .iter()
4470 .filter(|m| m.role == language_model::Role::User)
4471 .collect();
4472
4473 let last_user = user_messages.last().unwrap();
4474 let content_str = last_user.content[0].to_str().unwrap();
4475 assert!(
4476 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4477 "context_low_prompt should be sent when interrupting: got {:?}",
4478 content_str
4479 );
4480}
4481
4482#[gpui::test]
4483async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4484 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4485 let fake_model = model.as_fake();
4486
4487 cx.update(|cx| {
4488 cx.update_flags(true, vec!["subagents".to_string()]);
4489 });
4490
4491 let subagent_context = SubagentContext {
4492 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4493 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4494 depth: 1,
4495 summary_prompt: "Summarize".to_string(),
4496 context_low_prompt: "Context low".to_string(),
4497 };
4498
4499 let project = thread.read_with(cx, |t, _| t.project.clone());
4500 let project_context = cx.new(|_cx| ProjectContext::default());
4501 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4502 let context_server_registry =
4503 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4504
4505 let subagent = cx.new(|cx| {
4506 Thread::new_subagent(
4507 project.clone(),
4508 project_context,
4509 context_server_registry,
4510 Templates::new(),
4511 model.clone(),
4512 subagent_context,
4513 std::collections::BTreeMap::new(),
4514 cx,
4515 )
4516 });
4517
4518 subagent
4519 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4520 .unwrap();
4521 cx.run_until_parked();
4522
4523 let max_tokens = model.max_token_count();
4524 let high_usage = language_model::TokenUsage {
4525 input_tokens: (max_tokens as f64 * 0.80) as u64,
4526 output_tokens: 0,
4527 cache_creation_input_tokens: 0,
4528 cache_read_input_tokens: 0,
4529 };
4530
4531 fake_model
4532 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4533 fake_model.send_last_completion_stream_text_chunk("Working...");
4534 fake_model
4535 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4536 fake_model.end_last_completion_stream();
4537 cx.run_until_parked();
4538
4539 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4540 assert!(usage.is_some(), "should have token usage after completion");
4541
4542 let usage = usage.unwrap();
4543 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4544 assert!(
4545 remaining_ratio <= 0.25,
4546 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4547 remaining_ratio * 100.0
4548 );
4549}
4550
4551#[gpui::test]
4552async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4553 init_test(cx);
4554
4555 cx.update(|cx| {
4556 cx.update_flags(true, vec!["subagents".to_string()]);
4557 });
4558
4559 let fs = FakeFs::new(cx.executor());
4560 fs.insert_tree(path!("/test"), json!({})).await;
4561 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4562 let project_context = cx.new(|_cx| ProjectContext::default());
4563 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4564 let context_server_registry =
4565 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4566 let model = Arc::new(FakeLanguageModel::default());
4567
4568 let parent = cx.new(|cx| {
4569 let mut thread = Thread::new(
4570 project.clone(),
4571 project_context.clone(),
4572 context_server_registry.clone(),
4573 Templates::new(),
4574 Some(model.clone()),
4575 cx,
4576 );
4577 thread.add_tool(EchoTool);
4578 thread
4579 });
4580
4581 #[allow(clippy::arc_with_non_send_sync)]
4582 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4583
4584 let allowed_tools = Some(vec!["nonexistent_tool".to_string()]);
4585 let result = cx.read(|cx| tool.validate_allowed_tools(&allowed_tools, cx));
4586
4587 assert!(result.is_err(), "should reject unknown tool");
4588 let err_msg = result.unwrap_err().to_string();
4589 assert!(
4590 err_msg.contains("nonexistent_tool"),
4591 "error should mention the invalid tool name: {}",
4592 err_msg
4593 );
4594 assert!(
4595 err_msg.contains("do not exist"),
4596 "error should explain the tool does not exist: {}",
4597 err_msg
4598 );
4599}
4600
4601#[gpui::test]
4602async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4603 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4604 let fake_model = model.as_fake();
4605
4606 cx.update(|cx| {
4607 cx.update_flags(true, vec!["subagents".to_string()]);
4608 });
4609
4610 let subagent_context = SubagentContext {
4611 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4612 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4613 depth: 1,
4614 summary_prompt: "Summarize".to_string(),
4615 context_low_prompt: "Context low".to_string(),
4616 };
4617
4618 let project = thread.read_with(cx, |t, _| t.project.clone());
4619 let project_context = cx.new(|_cx| ProjectContext::default());
4620 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4621 let context_server_registry =
4622 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4623
4624 let subagent = cx.new(|cx| {
4625 Thread::new_subagent(
4626 project.clone(),
4627 project_context,
4628 context_server_registry,
4629 Templates::new(),
4630 model.clone(),
4631 subagent_context,
4632 std::collections::BTreeMap::new(),
4633 cx,
4634 )
4635 });
4636
4637 subagent
4638 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4639 .unwrap();
4640 cx.run_until_parked();
4641
4642 fake_model
4643 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4644 fake_model.end_last_completion_stream();
4645 cx.run_until_parked();
4646
4647 subagent.read_with(cx, |thread, _| {
4648 assert!(
4649 thread.is_turn_complete(),
4650 "turn should complete even with empty response"
4651 );
4652 });
4653}
4654
4655#[gpui::test]
4656async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4657 init_test(cx);
4658
4659 cx.update(|cx| {
4660 cx.update_flags(true, vec!["subagents".to_string()]);
4661 });
4662
4663 let fs = FakeFs::new(cx.executor());
4664 fs.insert_tree(path!("/test"), json!({})).await;
4665 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4666 let project_context = cx.new(|_cx| ProjectContext::default());
4667 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4668 let context_server_registry =
4669 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4670 let model = Arc::new(FakeLanguageModel::default());
4671
4672 let depth_1_context = SubagentContext {
4673 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4674 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4675 depth: 1,
4676 summary_prompt: "Summarize".to_string(),
4677 context_low_prompt: "Context low".to_string(),
4678 };
4679
4680 let depth_1_subagent = cx.new(|cx| {
4681 Thread::new_subagent(
4682 project.clone(),
4683 project_context.clone(),
4684 context_server_registry.clone(),
4685 Templates::new(),
4686 model.clone(),
4687 depth_1_context,
4688 std::collections::BTreeMap::new(),
4689 cx,
4690 )
4691 });
4692
4693 depth_1_subagent.read_with(cx, |thread, _| {
4694 assert_eq!(thread.depth(), 1);
4695 assert!(thread.is_subagent());
4696 });
4697
4698 let depth_2_context = SubagentContext {
4699 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4700 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4701 depth: 2,
4702 summary_prompt: "Summarize depth 2".to_string(),
4703 context_low_prompt: "Context low depth 2".to_string(),
4704 };
4705
4706 let depth_2_subagent = cx.new(|cx| {
4707 Thread::new_subagent(
4708 project.clone(),
4709 project_context.clone(),
4710 context_server_registry.clone(),
4711 Templates::new(),
4712 model.clone(),
4713 depth_2_context,
4714 std::collections::BTreeMap::new(),
4715 cx,
4716 )
4717 });
4718
4719 depth_2_subagent.read_with(cx, |thread, _| {
4720 assert_eq!(thread.depth(), 2);
4721 assert!(thread.is_subagent());
4722 });
4723
4724 depth_2_subagent
4725 .update(cx, |thread, cx| {
4726 thread.submit_user_message("Nested task", cx)
4727 })
4728 .unwrap();
4729 cx.run_until_parked();
4730
4731 let pending = model.as_fake().pending_completions();
4732 assert!(
4733 !pending.is_empty(),
4734 "depth-2 subagent should be able to submit messages"
4735 );
4736}
4737
4738#[gpui::test]
4739async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4740 init_test(cx);
4741 always_allow_tools(cx);
4742
4743 cx.update(|cx| {
4744 cx.update_flags(true, vec!["subagents".to_string()]);
4745 });
4746
4747 let fs = FakeFs::new(cx.executor());
4748 fs.insert_tree(path!("/test"), json!({})).await;
4749 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4750 let project_context = cx.new(|_cx| ProjectContext::default());
4751 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4752 let context_server_registry =
4753 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4754 let model = Arc::new(FakeLanguageModel::default());
4755 let fake_model = model.as_fake();
4756
4757 let subagent_context = SubagentContext {
4758 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4759 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4760 depth: 1,
4761 summary_prompt: "Summarize what you did".to_string(),
4762 context_low_prompt: "Context low".to_string(),
4763 };
4764
4765 let subagent = cx.new(|cx| {
4766 let mut thread = Thread::new_subagent(
4767 project.clone(),
4768 project_context,
4769 context_server_registry,
4770 Templates::new(),
4771 model.clone(),
4772 subagent_context,
4773 std::collections::BTreeMap::new(),
4774 cx,
4775 );
4776 thread.add_tool(EchoTool);
4777 thread
4778 });
4779
4780 subagent.read_with(cx, |thread, _| {
4781 assert!(
4782 thread.has_registered_tool("echo"),
4783 "subagent should have echo tool"
4784 );
4785 });
4786
4787 subagent
4788 .update(cx, |thread, cx| {
4789 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4790 })
4791 .unwrap();
4792 cx.run_until_parked();
4793
4794 let tool_use = LanguageModelToolUse {
4795 id: "tool_call_1".into(),
4796 name: EchoTool::name().into(),
4797 raw_input: json!({"text": "hello world"}).to_string(),
4798 input: json!({"text": "hello world"}),
4799 is_input_complete: true,
4800 thought_signature: None,
4801 };
4802 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4803 fake_model.end_last_completion_stream();
4804 cx.run_until_parked();
4805
4806 let pending = fake_model.pending_completions();
4807 assert!(
4808 !pending.is_empty(),
4809 "should have pending completion after tool use"
4810 );
4811
4812 let last_completion = pending.last().unwrap();
4813 let has_tool_result = last_completion.messages.iter().any(|m| {
4814 m.content
4815 .iter()
4816 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4817 });
4818 assert!(
4819 has_tool_result,
4820 "tool result should be in the messages sent back to the model"
4821 );
4822}
4823
4824#[gpui::test]
4825async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4826 init_test(cx);
4827
4828 cx.update(|cx| {
4829 cx.update_flags(true, vec!["subagents".to_string()]);
4830 });
4831
4832 let fs = FakeFs::new(cx.executor());
4833 fs.insert_tree(path!("/test"), json!({})).await;
4834 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4835 let project_context = cx.new(|_cx| ProjectContext::default());
4836 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4837 let context_server_registry =
4838 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4839 let model = Arc::new(FakeLanguageModel::default());
4840
4841 let parent = cx.new(|cx| {
4842 Thread::new(
4843 project.clone(),
4844 project_context.clone(),
4845 context_server_registry.clone(),
4846 Templates::new(),
4847 Some(model.clone()),
4848 cx,
4849 )
4850 });
4851
4852 let mut subagents = Vec::new();
4853 for i in 0..MAX_PARALLEL_SUBAGENTS {
4854 let subagent_context = SubagentContext {
4855 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4856 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4857 depth: 1,
4858 summary_prompt: "Summarize".to_string(),
4859 context_low_prompt: "Context low".to_string(),
4860 };
4861
4862 let subagent = cx.new(|cx| {
4863 Thread::new_subagent(
4864 project.clone(),
4865 project_context.clone(),
4866 context_server_registry.clone(),
4867 Templates::new(),
4868 model.clone(),
4869 subagent_context,
4870 std::collections::BTreeMap::new(),
4871 cx,
4872 )
4873 });
4874
4875 parent.update(cx, |thread, _cx| {
4876 thread.register_running_subagent(subagent.downgrade());
4877 });
4878 subagents.push(subagent);
4879 }
4880
4881 parent.read_with(cx, |thread, _| {
4882 assert_eq!(
4883 thread.running_subagent_count(),
4884 MAX_PARALLEL_SUBAGENTS,
4885 "should have MAX_PARALLEL_SUBAGENTS registered"
4886 );
4887 });
4888
4889 #[allow(clippy::arc_with_non_send_sync)]
4890 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4891
4892 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4893
4894 let result = cx.update(|cx| {
4895 tool.run(
4896 SubagentToolInput {
4897 label: "Test".to_string(),
4898 task_prompt: "Do something".to_string(),
4899 summary_prompt: "Summarize".to_string(),
4900 context_low_prompt: "Context low".to_string(),
4901 timeout_ms: None,
4902 allowed_tools: None,
4903 },
4904 event_stream,
4905 cx,
4906 )
4907 });
4908
4909 let err = result.await.unwrap_err();
4910 assert!(
4911 err.to_string().contains("Maximum parallel subagents"),
4912 "should reject when max parallel subagents reached: {}",
4913 err
4914 );
4915
4916 drop(subagents);
4917}
4918
4919#[gpui::test]
4920async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4921 init_test(cx);
4922 always_allow_tools(cx);
4923
4924 cx.update(|cx| {
4925 cx.update_flags(true, vec!["subagents".to_string()]);
4926 });
4927
4928 let fs = FakeFs::new(cx.executor());
4929 fs.insert_tree(path!("/test"), json!({})).await;
4930 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4931 let project_context = cx.new(|_cx| ProjectContext::default());
4932 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4933 let context_server_registry =
4934 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4935 let model = Arc::new(FakeLanguageModel::default());
4936 let fake_model = model.as_fake();
4937
4938 let parent = cx.new(|cx| {
4939 let mut thread = Thread::new(
4940 project.clone(),
4941 project_context.clone(),
4942 context_server_registry.clone(),
4943 Templates::new(),
4944 Some(model.clone()),
4945 cx,
4946 );
4947 thread.add_tool(EchoTool);
4948 thread
4949 });
4950
4951 #[allow(clippy::arc_with_non_send_sync)]
4952 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4953
4954 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4955
4956 let task = cx.update(|cx| {
4957 tool.run(
4958 SubagentToolInput {
4959 label: "Research task".to_string(),
4960 task_prompt: "Find all TODOs in the codebase".to_string(),
4961 summary_prompt: "Summarize what you found".to_string(),
4962 context_low_prompt: "Context low, wrap up".to_string(),
4963 timeout_ms: None,
4964 allowed_tools: None,
4965 },
4966 event_stream,
4967 cx,
4968 )
4969 });
4970
4971 cx.run_until_parked();
4972
4973 let pending = fake_model.pending_completions();
4974 assert!(
4975 !pending.is_empty(),
4976 "subagent should have started and sent a completion request"
4977 );
4978
4979 let first_completion = &pending[0];
4980 let has_task_prompt = first_completion.messages.iter().any(|m| {
4981 m.role == language_model::Role::User
4982 && m.content
4983 .iter()
4984 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
4985 });
4986 assert!(has_task_prompt, "task prompt should be sent to subagent");
4987
4988 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
4989 fake_model
4990 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4991 fake_model.end_last_completion_stream();
4992 cx.run_until_parked();
4993
4994 let pending = fake_model.pending_completions();
4995 assert!(
4996 !pending.is_empty(),
4997 "should have pending completion for summary request"
4998 );
4999
5000 let last_completion = pending.last().unwrap();
5001 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5002 m.role == language_model::Role::User
5003 && m.content.iter().any(|c| {
5004 c.to_str()
5005 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5006 .unwrap_or(false)
5007 })
5008 });
5009 assert!(
5010 has_summary_prompt,
5011 "summary prompt should be sent after task completion"
5012 );
5013
5014 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5015 fake_model
5016 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5017 fake_model.end_last_completion_stream();
5018 cx.run_until_parked();
5019
5020 let result = task.await;
5021 assert!(result.is_ok(), "subagent tool should complete successfully");
5022
5023 let summary = result.unwrap();
5024 assert!(
5025 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5026 "summary should contain subagent's response: {}",
5027 summary
5028 );
5029}
5030
5031#[gpui::test]
5032async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5033 init_test(cx);
5034
5035 let fs = FakeFs::new(cx.executor());
5036 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5037 .await;
5038 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5039
5040 cx.update(|cx| {
5041 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5042 settings.tool_permissions.tools.insert(
5043 "edit_file".into(),
5044 agent_settings::ToolRules {
5045 default_mode: settings::ToolPermissionMode::Allow,
5046 always_allow: vec![],
5047 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5048 always_confirm: vec![],
5049 invalid_patterns: vec![],
5050 },
5051 );
5052 agent_settings::AgentSettings::override_global(settings, cx);
5053 });
5054
5055 let context_server_registry =
5056 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5057 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5058 let templates = crate::Templates::new();
5059 let thread = cx.new(|cx| {
5060 crate::Thread::new(
5061 project.clone(),
5062 cx.new(|_cx| prompt_store::ProjectContext::default()),
5063 context_server_registry,
5064 templates.clone(),
5065 None,
5066 cx,
5067 )
5068 });
5069
5070 #[allow(clippy::arc_with_non_send_sync)]
5071 let tool = Arc::new(crate::EditFileTool::new(
5072 project.clone(),
5073 thread.downgrade(),
5074 language_registry,
5075 templates,
5076 ));
5077 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5078
5079 let task = cx.update(|cx| {
5080 tool.run(
5081 crate::EditFileToolInput {
5082 display_description: "Edit sensitive file".to_string(),
5083 path: "root/sensitive_config.txt".into(),
5084 mode: crate::EditFileMode::Edit,
5085 },
5086 event_stream,
5087 cx,
5088 )
5089 });
5090
5091 let result = task.await;
5092 assert!(result.is_err(), "expected edit to be blocked");
5093 assert!(
5094 result.unwrap_err().to_string().contains("blocked"),
5095 "error should mention the edit was blocked"
5096 );
5097}
5098
5099#[gpui::test]
5100async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5101 init_test(cx);
5102
5103 let fs = FakeFs::new(cx.executor());
5104 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5105 .await;
5106 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5107
5108 cx.update(|cx| {
5109 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5110 settings.tool_permissions.tools.insert(
5111 "delete_path".into(),
5112 agent_settings::ToolRules {
5113 default_mode: settings::ToolPermissionMode::Allow,
5114 always_allow: vec![],
5115 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5116 always_confirm: vec![],
5117 invalid_patterns: vec![],
5118 },
5119 );
5120 agent_settings::AgentSettings::override_global(settings, cx);
5121 });
5122
5123 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5124
5125 #[allow(clippy::arc_with_non_send_sync)]
5126 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5127 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5128
5129 let task = cx.update(|cx| {
5130 tool.run(
5131 crate::DeletePathToolInput {
5132 path: "root/important_data.txt".to_string(),
5133 },
5134 event_stream,
5135 cx,
5136 )
5137 });
5138
5139 let result = task.await;
5140 assert!(result.is_err(), "expected deletion to be blocked");
5141 assert!(
5142 result.unwrap_err().to_string().contains("blocked"),
5143 "error should mention the deletion was blocked"
5144 );
5145}
5146
5147#[gpui::test]
5148async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5149 init_test(cx);
5150
5151 let fs = FakeFs::new(cx.executor());
5152 fs.insert_tree(
5153 "/root",
5154 json!({
5155 "safe.txt": "content",
5156 "protected": {}
5157 }),
5158 )
5159 .await;
5160 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5161
5162 cx.update(|cx| {
5163 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5164 settings.tool_permissions.tools.insert(
5165 "move_path".into(),
5166 agent_settings::ToolRules {
5167 default_mode: settings::ToolPermissionMode::Allow,
5168 always_allow: vec![],
5169 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5170 always_confirm: vec![],
5171 invalid_patterns: vec![],
5172 },
5173 );
5174 agent_settings::AgentSettings::override_global(settings, cx);
5175 });
5176
5177 #[allow(clippy::arc_with_non_send_sync)]
5178 let tool = Arc::new(crate::MovePathTool::new(project));
5179 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5180
5181 let task = cx.update(|cx| {
5182 tool.run(
5183 crate::MovePathToolInput {
5184 source_path: "root/safe.txt".to_string(),
5185 destination_path: "root/protected/safe.txt".to_string(),
5186 },
5187 event_stream,
5188 cx,
5189 )
5190 });
5191
5192 let result = task.await;
5193 assert!(
5194 result.is_err(),
5195 "expected move to be blocked due to destination path"
5196 );
5197 assert!(
5198 result.unwrap_err().to_string().contains("blocked"),
5199 "error should mention the move was blocked"
5200 );
5201}
5202
5203#[gpui::test]
5204async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5205 init_test(cx);
5206
5207 let fs = FakeFs::new(cx.executor());
5208 fs.insert_tree(
5209 "/root",
5210 json!({
5211 "secret.txt": "secret content",
5212 "public": {}
5213 }),
5214 )
5215 .await;
5216 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5217
5218 cx.update(|cx| {
5219 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5220 settings.tool_permissions.tools.insert(
5221 "move_path".into(),
5222 agent_settings::ToolRules {
5223 default_mode: settings::ToolPermissionMode::Allow,
5224 always_allow: vec![],
5225 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5226 always_confirm: vec![],
5227 invalid_patterns: vec![],
5228 },
5229 );
5230 agent_settings::AgentSettings::override_global(settings, cx);
5231 });
5232
5233 #[allow(clippy::arc_with_non_send_sync)]
5234 let tool = Arc::new(crate::MovePathTool::new(project));
5235 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5236
5237 let task = cx.update(|cx| {
5238 tool.run(
5239 crate::MovePathToolInput {
5240 source_path: "root/secret.txt".to_string(),
5241 destination_path: "root/public/not_secret.txt".to_string(),
5242 },
5243 event_stream,
5244 cx,
5245 )
5246 });
5247
5248 let result = task.await;
5249 assert!(
5250 result.is_err(),
5251 "expected move to be blocked due to source path"
5252 );
5253 assert!(
5254 result.unwrap_err().to_string().contains("blocked"),
5255 "error should mention the move was blocked"
5256 );
5257}
5258
5259#[gpui::test]
5260async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5261 init_test(cx);
5262
5263 let fs = FakeFs::new(cx.executor());
5264 fs.insert_tree(
5265 "/root",
5266 json!({
5267 "confidential.txt": "confidential data",
5268 "dest": {}
5269 }),
5270 )
5271 .await;
5272 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5273
5274 cx.update(|cx| {
5275 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5276 settings.tool_permissions.tools.insert(
5277 "copy_path".into(),
5278 agent_settings::ToolRules {
5279 default_mode: settings::ToolPermissionMode::Allow,
5280 always_allow: vec![],
5281 always_deny: vec![
5282 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5283 ],
5284 always_confirm: vec![],
5285 invalid_patterns: vec![],
5286 },
5287 );
5288 agent_settings::AgentSettings::override_global(settings, cx);
5289 });
5290
5291 #[allow(clippy::arc_with_non_send_sync)]
5292 let tool = Arc::new(crate::CopyPathTool::new(project));
5293 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5294
5295 let task = cx.update(|cx| {
5296 tool.run(
5297 crate::CopyPathToolInput {
5298 source_path: "root/confidential.txt".to_string(),
5299 destination_path: "root/dest/copy.txt".to_string(),
5300 },
5301 event_stream,
5302 cx,
5303 )
5304 });
5305
5306 let result = task.await;
5307 assert!(result.is_err(), "expected copy to be blocked");
5308 assert!(
5309 result.unwrap_err().to_string().contains("blocked"),
5310 "error should mention the copy was blocked"
5311 );
5312}
5313
5314#[gpui::test]
5315async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5316 init_test(cx);
5317
5318 let fs = FakeFs::new(cx.executor());
5319 fs.insert_tree(
5320 "/root",
5321 json!({
5322 "normal.txt": "normal content",
5323 "readonly": {
5324 "config.txt": "readonly content"
5325 }
5326 }),
5327 )
5328 .await;
5329 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5330
5331 cx.update(|cx| {
5332 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5333 settings.tool_permissions.tools.insert(
5334 "save_file".into(),
5335 agent_settings::ToolRules {
5336 default_mode: settings::ToolPermissionMode::Allow,
5337 always_allow: vec![],
5338 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5339 always_confirm: vec![],
5340 invalid_patterns: vec![],
5341 },
5342 );
5343 agent_settings::AgentSettings::override_global(settings, cx);
5344 });
5345
5346 #[allow(clippy::arc_with_non_send_sync)]
5347 let tool = Arc::new(crate::SaveFileTool::new(project));
5348 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5349
5350 let task = cx.update(|cx| {
5351 tool.run(
5352 crate::SaveFileToolInput {
5353 paths: vec![
5354 std::path::PathBuf::from("root/normal.txt"),
5355 std::path::PathBuf::from("root/readonly/config.txt"),
5356 ],
5357 },
5358 event_stream,
5359 cx,
5360 )
5361 });
5362
5363 let result = task.await;
5364 assert!(
5365 result.is_err(),
5366 "expected save to be blocked due to denied path"
5367 );
5368 assert!(
5369 result.unwrap_err().to_string().contains("blocked"),
5370 "error should mention the save was blocked"
5371 );
5372}
5373
5374#[gpui::test]
5375async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5376 init_test(cx);
5377
5378 let fs = FakeFs::new(cx.executor());
5379 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5380 .await;
5381 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5382
5383 cx.update(|cx| {
5384 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5385 settings.always_allow_tool_actions = false;
5386 settings.tool_permissions.tools.insert(
5387 "save_file".into(),
5388 agent_settings::ToolRules {
5389 default_mode: settings::ToolPermissionMode::Allow,
5390 always_allow: vec![],
5391 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5392 always_confirm: vec![],
5393 invalid_patterns: vec![],
5394 },
5395 );
5396 agent_settings::AgentSettings::override_global(settings, cx);
5397 });
5398
5399 #[allow(clippy::arc_with_non_send_sync)]
5400 let tool = Arc::new(crate::SaveFileTool::new(project));
5401 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5402
5403 let task = cx.update(|cx| {
5404 tool.run(
5405 crate::SaveFileToolInput {
5406 paths: vec![std::path::PathBuf::from("root/config.secret")],
5407 },
5408 event_stream,
5409 cx,
5410 )
5411 });
5412
5413 let result = task.await;
5414 assert!(result.is_err(), "expected save to be blocked");
5415 assert!(
5416 result.unwrap_err().to_string().contains("blocked"),
5417 "error should mention the save was blocked"
5418 );
5419}
5420
5421#[gpui::test]
5422async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5423 init_test(cx);
5424
5425 cx.update(|cx| {
5426 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5427 settings.tool_permissions.tools.insert(
5428 "web_search".into(),
5429 agent_settings::ToolRules {
5430 default_mode: settings::ToolPermissionMode::Allow,
5431 always_allow: vec![],
5432 always_deny: vec![
5433 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5434 ],
5435 always_confirm: vec![],
5436 invalid_patterns: vec![],
5437 },
5438 );
5439 agent_settings::AgentSettings::override_global(settings, cx);
5440 });
5441
5442 #[allow(clippy::arc_with_non_send_sync)]
5443 let tool = Arc::new(crate::WebSearchTool);
5444 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5445
5446 let input: crate::WebSearchToolInput =
5447 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5448
5449 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5450
5451 let result = task.await;
5452 assert!(result.is_err(), "expected search to be blocked");
5453 assert!(
5454 result.unwrap_err().to_string().contains("blocked"),
5455 "error should mention the search was blocked"
5456 );
5457}
5458
5459#[gpui::test]
5460async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5461 init_test(cx);
5462
5463 let fs = FakeFs::new(cx.executor());
5464 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5465 .await;
5466 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5467
5468 cx.update(|cx| {
5469 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5470 settings.always_allow_tool_actions = false;
5471 settings.tool_permissions.tools.insert(
5472 "edit_file".into(),
5473 agent_settings::ToolRules {
5474 default_mode: settings::ToolPermissionMode::Confirm,
5475 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5476 always_deny: vec![],
5477 always_confirm: vec![],
5478 invalid_patterns: vec![],
5479 },
5480 );
5481 agent_settings::AgentSettings::override_global(settings, cx);
5482 });
5483
5484 let context_server_registry =
5485 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5486 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5487 let templates = crate::Templates::new();
5488 let thread = cx.new(|cx| {
5489 crate::Thread::new(
5490 project.clone(),
5491 cx.new(|_cx| prompt_store::ProjectContext::default()),
5492 context_server_registry,
5493 templates.clone(),
5494 None,
5495 cx,
5496 )
5497 });
5498
5499 #[allow(clippy::arc_with_non_send_sync)]
5500 let tool = Arc::new(crate::EditFileTool::new(
5501 project,
5502 thread.downgrade(),
5503 language_registry,
5504 templates,
5505 ));
5506 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5507
5508 let _task = cx.update(|cx| {
5509 tool.run(
5510 crate::EditFileToolInput {
5511 display_description: "Edit README".to_string(),
5512 path: "root/README.md".into(),
5513 mode: crate::EditFileMode::Edit,
5514 },
5515 event_stream,
5516 cx,
5517 )
5518 });
5519
5520 cx.run_until_parked();
5521
5522 let event = rx.try_next();
5523 assert!(
5524 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5525 "expected no authorization request for allowed .md file"
5526 );
5527}
5528
5529#[gpui::test]
5530async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5531 init_test(cx);
5532
5533 cx.update(|cx| {
5534 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5535 settings.tool_permissions.tools.insert(
5536 "fetch".into(),
5537 agent_settings::ToolRules {
5538 default_mode: settings::ToolPermissionMode::Allow,
5539 always_allow: vec![],
5540 always_deny: vec![
5541 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5542 ],
5543 always_confirm: vec![],
5544 invalid_patterns: vec![],
5545 },
5546 );
5547 agent_settings::AgentSettings::override_global(settings, cx);
5548 });
5549
5550 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5551
5552 #[allow(clippy::arc_with_non_send_sync)]
5553 let tool = Arc::new(crate::FetchTool::new(http_client));
5554 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5555
5556 let input: crate::FetchToolInput =
5557 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5558
5559 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5560
5561 let result = task.await;
5562 assert!(result.is_err(), "expected fetch to be blocked");
5563 assert!(
5564 result.unwrap_err().to_string().contains("blocked"),
5565 "error should mention the fetch was blocked"
5566 );
5567}
5568
5569#[gpui::test]
5570async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5571 init_test(cx);
5572
5573 cx.update(|cx| {
5574 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5575 settings.always_allow_tool_actions = false;
5576 settings.tool_permissions.tools.insert(
5577 "fetch".into(),
5578 agent_settings::ToolRules {
5579 default_mode: settings::ToolPermissionMode::Confirm,
5580 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5581 always_deny: vec![],
5582 always_confirm: vec![],
5583 invalid_patterns: vec![],
5584 },
5585 );
5586 agent_settings::AgentSettings::override_global(settings, cx);
5587 });
5588
5589 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5590
5591 #[allow(clippy::arc_with_non_send_sync)]
5592 let tool = Arc::new(crate::FetchTool::new(http_client));
5593 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5594
5595 let input: crate::FetchToolInput =
5596 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5597
5598 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5599
5600 cx.run_until_parked();
5601
5602 let event = rx.try_next();
5603 assert!(
5604 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5605 "expected no authorization request for allowed docs.rs URL"
5606 );
5607}
5608
5609#[gpui::test]
5610async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5611 init_test(cx);
5612 always_allow_tools(cx);
5613
5614 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5615 let fake_model = model.as_fake();
5616
5617 // Add a tool so we can simulate tool calls
5618 thread.update(cx, |thread, _cx| {
5619 thread.add_tool(EchoTool);
5620 });
5621
5622 // Start a turn by sending a message
5623 let mut events = thread
5624 .update(cx, |thread, cx| {
5625 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5626 })
5627 .unwrap();
5628 cx.run_until_parked();
5629
5630 // Simulate the model making a tool call
5631 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5632 LanguageModelToolUse {
5633 id: "tool_1".into(),
5634 name: "echo".into(),
5635 raw_input: r#"{"text": "hello"}"#.into(),
5636 input: json!({"text": "hello"}),
5637 is_input_complete: true,
5638 thought_signature: None,
5639 },
5640 ));
5641 fake_model
5642 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5643
5644 // Signal that a message is queued before ending the stream
5645 thread.update(cx, |thread, _cx| {
5646 thread.set_has_queued_message(true);
5647 });
5648
5649 // Now end the stream - tool will run, and the boundary check should see the queue
5650 fake_model.end_last_completion_stream();
5651
5652 // Collect all events until the turn stops
5653 let all_events = collect_events_until_stop(&mut events, cx).await;
5654
5655 // Verify we received the tool call event
5656 let tool_call_ids: Vec<_> = all_events
5657 .iter()
5658 .filter_map(|e| match e {
5659 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5660 _ => None,
5661 })
5662 .collect();
5663 assert_eq!(
5664 tool_call_ids,
5665 vec!["tool_1"],
5666 "Should have received a tool call event for our echo tool"
5667 );
5668
5669 // The turn should have stopped with EndTurn
5670 let stop_reasons = stop_events(all_events);
5671 assert_eq!(
5672 stop_reasons,
5673 vec![acp::StopReason::EndTurn],
5674 "Turn should have ended after tool completion due to queued message"
5675 );
5676
5677 // Verify the queued message flag is still set
5678 thread.update(cx, |thread, _cx| {
5679 assert!(
5680 thread.has_queued_message(),
5681 "Should still have queued message flag set"
5682 );
5683 });
5684
5685 // Thread should be idle now
5686 thread.update(cx, |thread, _cx| {
5687 assert!(
5688 thread.is_turn_complete(),
5689 "Thread should not be running after turn ends"
5690 );
5691 });
5692}