1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
66struct FakeTerminalHandle {
67 killed: Arc<AtomicBool>,
68 stopped_by_user: Arc<AtomicBool>,
69 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
70 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
71 output: acp::TerminalOutputResponse,
72 id: acp::TerminalId,
73}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeThreadEnvironment {
159 handle: Rc<FakeTerminalHandle>,
160}
161
162impl crate::ThreadEnvironment for FakeThreadEnvironment {
163 fn create_terminal(
164 &self,
165 _command: String,
166 _cwd: Option<std::path::PathBuf>,
167 _output_byte_limit: Option<u64>,
168 _cx: &mut AsyncApp,
169 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
170 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
171 }
172}
173
174/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
175struct MultiTerminalEnvironment {
176 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
177}
178
179impl MultiTerminalEnvironment {
180 fn new() -> Self {
181 Self {
182 handles: std::cell::RefCell::new(Vec::new()),
183 }
184 }
185
186 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
187 self.handles.borrow().clone()
188 }
189}
190
191impl crate::ThreadEnvironment for MultiTerminalEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
200 self.handles.borrow_mut().push(handle.clone());
201 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
202 }
203}
204
205fn always_allow_tools(cx: &mut TestAppContext) {
206 cx.update(|cx| {
207 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
208 settings.always_allow_tool_actions = true;
209 agent_settings::AgentSettings::override_global(settings, cx);
210 });
211}
212
213#[gpui::test]
214async fn test_echo(cx: &mut TestAppContext) {
215 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
216 let fake_model = model.as_fake();
217
218 let events = thread
219 .update(cx, |thread, cx| {
220 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
221 })
222 .unwrap();
223 cx.run_until_parked();
224 fake_model.send_last_completion_stream_text_chunk("Hello");
225 fake_model
226 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
227 fake_model.end_last_completion_stream();
228
229 let events = events.collect().await;
230 thread.update(cx, |thread, _cx| {
231 assert_eq!(
232 thread.last_message().unwrap().to_markdown(),
233 indoc! {"
234 ## Assistant
235
236 Hello
237 "}
238 )
239 });
240 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
241}
242
/// Verifies that a command exceeding `timeout_ms` has its terminal handle killed, and that the
/// tool's result still includes the output captured up to that point.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // This fake only "exits" once something kills it, so the task should only finish after
    // the timeout path has killed the handle.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // The tool is built from an `Rc`-based environment, so the `Arc` presumably wraps a
    // non-`Send`/`Sync` value; the lint is silenced for that reason.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fuse and pin the task so it can be polled repeatedly with `now_or_never` without blocking.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll until the task completes. Between polls, drain ready work and wait on a 1ms executor
    // timer. NOTE(review): the 500ms budget is measured with the wall-clock `Instant`, while the
    // 1ms wait uses the executor's timer, which may be simulated in gpui tests. Confirm the
    // deadline bounds what is intended.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
308
309#[gpui::test]
310#[ignore]
311async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
312 init_test(cx);
313 always_allow_tools(cx);
314
315 let fs = FakeFs::new(cx.executor());
316 let project = Project::test(fs, [], cx).await;
317
318 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
319 let environment = Rc::new(FakeThreadEnvironment {
320 handle: handle.clone(),
321 });
322
323 #[allow(clippy::arc_with_non_send_sync)]
324 let tool = Arc::new(crate::TerminalTool::new(project, environment));
325 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
326
327 let _task = cx.update(|cx| {
328 tool.run(
329 crate::TerminalToolInput {
330 command: "sleep 1000".to_string(),
331 cd: ".".to_string(),
332 timeout_ms: None,
333 },
334 event_stream,
335 cx,
336 )
337 });
338
339 let update = rx.expect_update_fields().await;
340 assert!(
341 update.content.iter().any(|blocks| {
342 blocks
343 .iter()
344 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
345 }),
346 "expected tool call update to include terminal content"
347 );
348
349 cx.background_executor
350 .timer(Duration::from_millis(25))
351 .await;
352
353 assert!(
354 !handle.was_killed(),
355 "did not expect terminal handle to be killed without a timeout"
356 );
357}
358
359#[gpui::test]
360async fn test_thinking(cx: &mut TestAppContext) {
361 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
362 let fake_model = model.as_fake();
363
364 let events = thread
365 .update(cx, |thread, cx| {
366 thread.send(
367 UserMessageId::new(),
368 [indoc! {"
369 Testing:
370
371 Generate a thinking step where you just think the word 'Think',
372 and have your final answer be 'Hello'
373 "}],
374 cx,
375 )
376 })
377 .unwrap();
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
380 text: "Think".to_string(),
381 signature: None,
382 });
383 fake_model.send_last_completion_stream_text_chunk("Hello");
384 fake_model
385 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
386 fake_model.end_last_completion_stream();
387
388 let events = events.collect().await;
389 thread.update(cx, |thread, _cx| {
390 assert_eq!(
391 thread.last_message().unwrap().to_markdown(),
392 indoc! {"
393 ## Assistant
394
395 <think>Think</think>
396 Hello
397 "}
398 )
399 });
400 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
401}
402
403#[gpui::test]
404async fn test_system_prompt(cx: &mut TestAppContext) {
405 let ThreadTest {
406 model,
407 thread,
408 project_context,
409 ..
410 } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 project_context.update(cx, |project_context, _cx| {
414 project_context.shell = "test-shell".into()
415 });
416 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
417 thread
418 .update(cx, |thread, cx| {
419 thread.send(UserMessageId::new(), ["abc"], cx)
420 })
421 .unwrap();
422 cx.run_until_parked();
423 let mut pending_completions = fake_model.pending_completions();
424 assert_eq!(
425 pending_completions.len(),
426 1,
427 "unexpected pending completions: {:?}",
428 pending_completions
429 );
430
431 let pending_completion = pending_completions.pop().unwrap();
432 assert_eq!(pending_completion.messages[0].role, Role::System);
433
434 let system_message = &pending_completion.messages[0];
435 let system_prompt = system_message.content[0].to_str().unwrap();
436 assert!(
437 system_prompt.contains("test-shell"),
438 "unexpected system message: {:?}",
439 system_message
440 );
441 assert!(
442 system_prompt.contains("## Fixing Diagnostics"),
443 "unexpected system message: {:?}",
444 system_message
445 );
446}
447
448#[gpui::test]
449async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
450 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
451 let fake_model = model.as_fake();
452
453 thread
454 .update(cx, |thread, cx| {
455 thread.send(UserMessageId::new(), ["abc"], cx)
456 })
457 .unwrap();
458 cx.run_until_parked();
459 let mut pending_completions = fake_model.pending_completions();
460 assert_eq!(
461 pending_completions.len(),
462 1,
463 "unexpected pending completions: {:?}",
464 pending_completions
465 );
466
467 let pending_completion = pending_completions.pop().unwrap();
468 assert_eq!(pending_completion.messages[0].role, Role::System);
469
470 let system_message = &pending_completion.messages[0];
471 let system_prompt = system_message.content[0].to_str().unwrap();
472 assert!(
473 !system_prompt.contains("## Tool Use"),
474 "unexpected system message: {:?}",
475 system_message
476 );
477 assert!(
478 !system_prompt.contains("## Fixing Diagnostics"),
479 "unexpected system message: {:?}",
480 system_message
481 );
482}
483
484#[gpui::test]
485async fn test_prompt_caching(cx: &mut TestAppContext) {
486 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
487 let fake_model = model.as_fake();
488
489 // Send initial user message and verify it's cached
490 thread
491 .update(cx, |thread, cx| {
492 thread.send(UserMessageId::new(), ["Message 1"], cx)
493 })
494 .unwrap();
495 cx.run_until_parked();
496
497 let completion = fake_model.pending_completions().pop().unwrap();
498 assert_eq!(
499 completion.messages[1..],
500 vec![LanguageModelRequestMessage {
501 role: Role::User,
502 content: vec!["Message 1".into()],
503 cache: true,
504 reasoning_details: None,
505 }]
506 );
507 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
508 "Response to Message 1".into(),
509 ));
510 fake_model.end_last_completion_stream();
511 cx.run_until_parked();
512
513 // Send another user message and verify only the latest is cached
514 thread
515 .update(cx, |thread, cx| {
516 thread.send(UserMessageId::new(), ["Message 2"], cx)
517 })
518 .unwrap();
519 cx.run_until_parked();
520
521 let completion = fake_model.pending_completions().pop().unwrap();
522 assert_eq!(
523 completion.messages[1..],
524 vec![
525 LanguageModelRequestMessage {
526 role: Role::User,
527 content: vec!["Message 1".into()],
528 cache: false,
529 reasoning_details: None,
530 },
531 LanguageModelRequestMessage {
532 role: Role::Assistant,
533 content: vec!["Response to Message 1".into()],
534 cache: false,
535 reasoning_details: None,
536 },
537 LanguageModelRequestMessage {
538 role: Role::User,
539 content: vec!["Message 2".into()],
540 cache: true,
541 reasoning_details: None,
542 }
543 ]
544 );
545 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
546 "Response to Message 2".into(),
547 ));
548 fake_model.end_last_completion_stream();
549 cx.run_until_parked();
550
551 // Simulate a tool call and verify that the latest tool result is cached
552 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
553 thread
554 .update(cx, |thread, cx| {
555 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
556 })
557 .unwrap();
558 cx.run_until_parked();
559
560 let tool_use = LanguageModelToolUse {
561 id: "tool_1".into(),
562 name: EchoTool::name().into(),
563 raw_input: json!({"text": "test"}).to_string(),
564 input: json!({"text": "test"}),
565 is_input_complete: true,
566 thought_signature: None,
567 };
568 fake_model
569 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
570 fake_model.end_last_completion_stream();
571 cx.run_until_parked();
572
573 let completion = fake_model.pending_completions().pop().unwrap();
574 let tool_result = LanguageModelToolResult {
575 tool_use_id: "tool_1".into(),
576 tool_name: EchoTool::name().into(),
577 is_error: false,
578 content: "test".into(),
579 output: Some("test".into()),
580 };
581 assert_eq!(
582 completion.messages[1..],
583 vec![
584 LanguageModelRequestMessage {
585 role: Role::User,
586 content: vec!["Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::Assistant,
592 content: vec!["Response to Message 1".into()],
593 cache: false,
594 reasoning_details: None,
595 },
596 LanguageModelRequestMessage {
597 role: Role::User,
598 content: vec!["Message 2".into()],
599 cache: false,
600 reasoning_details: None,
601 },
602 LanguageModelRequestMessage {
603 role: Role::Assistant,
604 content: vec!["Response to Message 2".into()],
605 cache: false,
606 reasoning_details: None,
607 },
608 LanguageModelRequestMessage {
609 role: Role::User,
610 content: vec!["Use the echo tool".into()],
611 cache: false,
612 reasoning_details: None,
613 },
614 LanguageModelRequestMessage {
615 role: Role::Assistant,
616 content: vec![MessageContent::ToolUse(tool_use)],
617 cache: false,
618 reasoning_details: None,
619 },
620 LanguageModelRequestMessage {
621 role: Role::User,
622 content: vec![MessageContent::ToolResult(tool_result)],
623 cache: true,
624 reasoning_details: None,
625 }
626 ]
627 );
628}
629
630#[gpui::test]
631#[cfg_attr(not(feature = "e2e"), ignore)]
632async fn test_basic_tool_calls(cx: &mut TestAppContext) {
633 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
634
635 // Test a tool call that's likely to complete *before* streaming stops.
636 let events = thread
637 .update(cx, |thread, cx| {
638 thread.add_tool(EchoTool);
639 thread.send(
640 UserMessageId::new(),
641 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
642 cx,
643 )
644 })
645 .unwrap()
646 .collect()
647 .await;
648 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
649
650 // Test a tool calls that's likely to complete *after* streaming stops.
651 let events = thread
652 .update(cx, |thread, cx| {
653 thread.remove_tool(&EchoTool::name());
654 thread.add_tool(DelayTool);
655 thread.send(
656 UserMessageId::new(),
657 [
658 "Now call the delay tool with 200ms.",
659 "When the timer goes off, then you echo the output of the tool.",
660 ],
661 cx,
662 )
663 })
664 .unwrap()
665 .collect()
666 .await;
667 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
668 thread.update(cx, |thread, _cx| {
669 assert!(
670 thread
671 .last_message()
672 .unwrap()
673 .as_agent_message()
674 .unwrap()
675 .content
676 .iter()
677 .any(|content| {
678 if let AgentMessageContent::Text(text) = content {
679 text.contains("Ding")
680 } else {
681 false
682 }
683 }),
684 "{}",
685 thread.to_markdown()
686 );
687 });
688}
689
690#[gpui::test]
691#[cfg_attr(not(feature = "e2e"), ignore)]
692async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
693 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
694
695 // Test a tool call that's likely to complete *before* streaming stops.
696 let mut events = thread
697 .update(cx, |thread, cx| {
698 thread.add_tool(WordListTool);
699 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
700 })
701 .unwrap();
702
703 let mut saw_partial_tool_use = false;
704 while let Some(event) = events.next().await {
705 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
706 thread.update(cx, |thread, _cx| {
707 // Look for a tool use in the thread's last message
708 let message = thread.last_message().unwrap();
709 let agent_message = message.as_agent_message().unwrap();
710 let last_content = agent_message.content.last().unwrap();
711 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
712 assert_eq!(last_tool_use.name.as_ref(), "word_list");
713 if tool_call.status == acp::ToolCallStatus::Pending {
714 if !last_tool_use.is_input_complete
715 && last_tool_use.input.get("g").is_none()
716 {
717 saw_partial_tool_use = true;
718 }
719 } else {
720 last_tool_use
721 .input
722 .get("a")
723 .expect("'a' has streamed because input is now complete");
724 last_tool_use
725 .input
726 .get("g")
727 .expect("'g' has streamed because input is now complete");
728 }
729 } else {
730 panic!("last content should be a tool use");
731 }
732 });
733 }
734 }
735
736 assert!(
737 saw_partial_tool_use,
738 "should see at least one partially streamed tool use in the history"
739 );
740}
741
/// Exercises the permission flow for `ToolRequiringPermission`:
/// - one call allowed once and one call denied in the same turn;
/// - a third call answered with "always allow";
/// - a fourth call whose result is checked without awaiting any authorization request.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two permission-gated calls in one turn; each is expected to raise its own authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request's last message carries both results: success for the allowed
    // call and an error result for the denied one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    // No authorization is answered here, so the expected result below depends on the earlier
    // "always allow" choice taking effect.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
880
881#[gpui::test]
882async fn test_tool_hallucination(cx: &mut TestAppContext) {
883 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
884 let fake_model = model.as_fake();
885
886 let mut events = thread
887 .update(cx, |thread, cx| {
888 thread.send(UserMessageId::new(), ["abc"], cx)
889 })
890 .unwrap();
891 cx.run_until_parked();
892 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
893 LanguageModelToolUse {
894 id: "tool_id_1".into(),
895 name: "nonexistent_tool".into(),
896 raw_input: "{}".into(),
897 input: json!({}),
898 is_input_complete: true,
899 thought_signature: None,
900 },
901 ));
902 fake_model.end_last_completion_stream();
903
904 let tool_call = expect_tool_call(&mut events).await;
905 assert_eq!(tool_call.title, "nonexistent_tool");
906 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
907 let update = expect_tool_call_update_fields(&mut events).await;
908 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
909}
910
911async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
912 let event = events
913 .next()
914 .await
915 .expect("no tool call authorization event received")
916 .unwrap();
917 match event {
918 ThreadEvent::ToolCall(tool_call) => tool_call,
919 event => {
920 panic!("Unexpected event {event:?}");
921 }
922 }
923}
924
925async fn expect_tool_call_update_fields(
926 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
927) -> acp::ToolCallUpdate {
928 let event = events
929 .next()
930 .await
931 .expect("no tool call authorization event received")
932 .unwrap();
933 match event {
934 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
935 event => {
936 panic!("Unexpected event {event:?}");
937 }
938 }
939}
940
941async fn next_tool_call_authorization(
942 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
943) -> ToolCallAuthorization {
944 loop {
945 let event = events
946 .next()
947 .await
948 .expect("no tool call authorization event received")
949 .unwrap();
950 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
951 let permission_kinds = tool_call_authorization
952 .options
953 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
954 .map(|option| option.kind);
955 let allow_once = tool_call_authorization
956 .options
957 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
958 .map(|option| option.kind);
959
960 assert_eq!(
961 permission_kinds,
962 Some(acp::PermissionOptionKind::AllowAlways)
963 );
964 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
965 return tool_call_authorization;
966 }
967 }
968}
969
/// A terminal command with a recognizable program name yields three "allow" choices:
/// always for the tool, always for that command's pattern, and only this time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
988
/// Editing a file offers a tool-wide "always" choice plus one scoped to the file's directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new("edit_file", "src/main.rs");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1005
/// Fetching a URL offers a tool-wide "always" choice plus one scoped to the URL's domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context = ToolPermissionContext::new("fetch", "https://docs.rs/gpui");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for fetch"));
    assert!(labels.contains(&"Always for `docs.rs`"));
}
1022
/// A terminal command starting with a script path gets no command-pattern choice, leaving just
/// "always for terminal" and "only this time".
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new("terminal", "./deploy.sh --production");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Only this time"));
    assert!(!labels.iter().any(|label| label.contains("commands")));
}
1041
/// Checks the option ids for a terminal command. The allow and deny sides each expose a
/// tool-wide id, a one-off id, and a pattern-scoped id prefixed with the tool name.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new("terminal", "cargo build --release");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|entry| {
            (
                entry.allow.option_id.0.to_string(),
                entry.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(allow_ids.contains(&"always_allow:terminal".to_string()));
    assert!(allow_ids.contains(&"allow".to_string()));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal:")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.contains(&"always_deny:terminal".to_string()));
    assert!(deny_ids.contains(&"deny".to_string()));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal:")),
        "Missing deny pattern option"
    );
}
1078
1079#[gpui::test]
1080#[cfg_attr(not(feature = "e2e"), ignore)]
1081async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1082 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1083
1084 // Test concurrent tool calls with different delay times
1085 let events = thread
1086 .update(cx, |thread, cx| {
1087 thread.add_tool(DelayTool);
1088 thread.send(
1089 UserMessageId::new(),
1090 [
1091 "Call the delay tool twice in the same message.",
1092 "Once with 100ms. Once with 300ms.",
1093 "When both timers are complete, describe the outputs.",
1094 ],
1095 cx,
1096 )
1097 })
1098 .unwrap()
1099 .collect()
1100 .await;
1101
1102 let stop_reasons = stop_events(events);
1103 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1104
1105 thread.update(cx, |thread, _cx| {
1106 let last_message = thread.last_message().unwrap();
1107 let agent_message = last_message.as_agent_message().unwrap();
1108 let text = agent_message
1109 .content
1110 .iter()
1111 .filter_map(|content| {
1112 if let AgentMessageContent::Text(text) = content {
1113 Some(text.as_str())
1114 } else {
1115 None
1116 }
1117 })
1118 .collect::<String>();
1119
1120 assert!(text.contains("Ding"));
1121 });
1122}
1123
1124#[gpui::test]
1125async fn test_profiles(cx: &mut TestAppContext) {
1126 let ThreadTest {
1127 model, thread, fs, ..
1128 } = setup(cx, TestModel::Fake).await;
1129 let fake_model = model.as_fake();
1130
1131 thread.update(cx, |thread, _cx| {
1132 thread.add_tool(DelayTool);
1133 thread.add_tool(EchoTool);
1134 thread.add_tool(InfiniteTool);
1135 });
1136
1137 // Override profiles and wait for settings to be loaded.
1138 fs.insert_file(
1139 paths::settings_file(),
1140 json!({
1141 "agent": {
1142 "profiles": {
1143 "test-1": {
1144 "name": "Test Profile 1",
1145 "tools": {
1146 EchoTool::name(): true,
1147 DelayTool::name(): true,
1148 }
1149 },
1150 "test-2": {
1151 "name": "Test Profile 2",
1152 "tools": {
1153 InfiniteTool::name(): true,
1154 }
1155 }
1156 }
1157 }
1158 })
1159 .to_string()
1160 .into_bytes(),
1161 )
1162 .await;
1163 cx.run_until_parked();
1164
1165 // Test that test-1 profile (default) has echo and delay tools
1166 thread
1167 .update(cx, |thread, cx| {
1168 thread.set_profile(AgentProfileId("test-1".into()), cx);
1169 thread.send(UserMessageId::new(), ["test"], cx)
1170 })
1171 .unwrap();
1172 cx.run_until_parked();
1173
1174 let mut pending_completions = fake_model.pending_completions();
1175 assert_eq!(pending_completions.len(), 1);
1176 let completion = pending_completions.pop().unwrap();
1177 let tool_names: Vec<String> = completion
1178 .tools
1179 .iter()
1180 .map(|tool| tool.name.clone())
1181 .collect();
1182 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1183 fake_model.end_last_completion_stream();
1184
1185 // Switch to test-2 profile, and verify that it has only the infinite tool.
1186 thread
1187 .update(cx, |thread, cx| {
1188 thread.set_profile(AgentProfileId("test-2".into()), cx);
1189 thread.send(UserMessageId::new(), ["test2"], cx)
1190 })
1191 .unwrap();
1192 cx.run_until_parked();
1193 let mut pending_completions = fake_model.pending_completions();
1194 assert_eq!(pending_completions.len(), 1);
1195 let completion = pending_completions.pop().unwrap();
1196 let tool_names: Vec<String> = completion
1197 .tools
1198 .iter()
1199 .map(|tool| tool.name.clone())
1200 .collect();
1201 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1202}
1203
1204#[gpui::test]
1205async fn test_mcp_tools(cx: &mut TestAppContext) {
1206 let ThreadTest {
1207 model,
1208 thread,
1209 context_server_store,
1210 fs,
1211 ..
1212 } = setup(cx, TestModel::Fake).await;
1213 let fake_model = model.as_fake();
1214
1215 // Override profiles and wait for settings to be loaded.
1216 fs.insert_file(
1217 paths::settings_file(),
1218 json!({
1219 "agent": {
1220 "always_allow_tool_actions": true,
1221 "profiles": {
1222 "test": {
1223 "name": "Test Profile",
1224 "enable_all_context_servers": true,
1225 "tools": {
1226 EchoTool::name(): true,
1227 }
1228 },
1229 }
1230 }
1231 })
1232 .to_string()
1233 .into_bytes(),
1234 )
1235 .await;
1236 cx.run_until_parked();
1237 thread.update(cx, |thread, cx| {
1238 thread.set_profile(AgentProfileId("test".into()), cx)
1239 });
1240
1241 let mut mcp_tool_calls = setup_context_server(
1242 "test_server",
1243 vec![context_server::types::Tool {
1244 name: "echo".into(),
1245 description: None,
1246 input_schema: serde_json::to_value(EchoTool::input_schema(
1247 LanguageModelToolSchemaFormat::JsonSchema,
1248 ))
1249 .unwrap(),
1250 output_schema: None,
1251 annotations: None,
1252 }],
1253 &context_server_store,
1254 cx,
1255 );
1256
1257 let events = thread.update(cx, |thread, cx| {
1258 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1259 });
1260 cx.run_until_parked();
1261
1262 // Simulate the model calling the MCP tool.
1263 let completion = fake_model.pending_completions().pop().unwrap();
1264 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1265 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1266 LanguageModelToolUse {
1267 id: "tool_1".into(),
1268 name: "echo".into(),
1269 raw_input: json!({"text": "test"}).to_string(),
1270 input: json!({"text": "test"}),
1271 is_input_complete: true,
1272 thought_signature: None,
1273 },
1274 ));
1275 fake_model.end_last_completion_stream();
1276 cx.run_until_parked();
1277
1278 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1279 assert_eq!(tool_call_params.name, "echo");
1280 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1281 tool_call_response
1282 .send(context_server::types::CallToolResponse {
1283 content: vec![context_server::types::ToolResponseContent::Text {
1284 text: "test".into(),
1285 }],
1286 is_error: None,
1287 meta: None,
1288 structured_content: None,
1289 })
1290 .unwrap();
1291 cx.run_until_parked();
1292
1293 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1294 fake_model.send_last_completion_stream_text_chunk("Done!");
1295 fake_model.end_last_completion_stream();
1296 events.collect::<Vec<_>>().await;
1297
1298 // Send again after adding the echo tool, ensuring the name collision is resolved.
1299 let events = thread.update(cx, |thread, cx| {
1300 thread.add_tool(EchoTool);
1301 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1302 });
1303 cx.run_until_parked();
1304 let completion = fake_model.pending_completions().pop().unwrap();
1305 assert_eq!(
1306 tool_names_for_completion(&completion),
1307 vec!["echo", "test_server_echo"]
1308 );
1309 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1310 LanguageModelToolUse {
1311 id: "tool_2".into(),
1312 name: "test_server_echo".into(),
1313 raw_input: json!({"text": "mcp"}).to_string(),
1314 input: json!({"text": "mcp"}),
1315 is_input_complete: true,
1316 thought_signature: None,
1317 },
1318 ));
1319 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1320 LanguageModelToolUse {
1321 id: "tool_3".into(),
1322 name: "echo".into(),
1323 raw_input: json!({"text": "native"}).to_string(),
1324 input: json!({"text": "native"}),
1325 is_input_complete: true,
1326 thought_signature: None,
1327 },
1328 ));
1329 fake_model.end_last_completion_stream();
1330 cx.run_until_parked();
1331
1332 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1333 assert_eq!(tool_call_params.name, "echo");
1334 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1335 tool_call_response
1336 .send(context_server::types::CallToolResponse {
1337 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1338 is_error: None,
1339 meta: None,
1340 structured_content: None,
1341 })
1342 .unwrap();
1343 cx.run_until_parked();
1344
1345 // Ensure the tool results were inserted with the correct names.
1346 let completion = fake_model.pending_completions().pop().unwrap();
1347 assert_eq!(
1348 completion.messages.last().unwrap().content,
1349 vec![
1350 MessageContent::ToolResult(LanguageModelToolResult {
1351 tool_use_id: "tool_3".into(),
1352 tool_name: "echo".into(),
1353 is_error: false,
1354 content: "native".into(),
1355 output: Some("native".into()),
1356 },),
1357 MessageContent::ToolResult(LanguageModelToolResult {
1358 tool_use_id: "tool_2".into(),
1359 tool_name: "test_server_echo".into(),
1360 is_error: false,
1361 content: "mcp".into(),
1362 output: Some("mcp".into()),
1363 },),
1364 ]
1365 );
1366 fake_model.end_last_completion_stream();
1367 events.collect::<Vec<_>>().await;
1368}
1369
1370#[gpui::test]
1371async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1372 let ThreadTest {
1373 model,
1374 thread,
1375 context_server_store,
1376 fs,
1377 ..
1378 } = setup(cx, TestModel::Fake).await;
1379 let fake_model = model.as_fake();
1380
1381 // Set up a profile with all tools enabled
1382 fs.insert_file(
1383 paths::settings_file(),
1384 json!({
1385 "agent": {
1386 "profiles": {
1387 "test": {
1388 "name": "Test Profile",
1389 "enable_all_context_servers": true,
1390 "tools": {
1391 EchoTool::name(): true,
1392 DelayTool::name(): true,
1393 WordListTool::name(): true,
1394 ToolRequiringPermission::name(): true,
1395 InfiniteTool::name(): true,
1396 }
1397 },
1398 }
1399 }
1400 })
1401 .to_string()
1402 .into_bytes(),
1403 )
1404 .await;
1405 cx.run_until_parked();
1406
1407 thread.update(cx, |thread, cx| {
1408 thread.set_profile(AgentProfileId("test".into()), cx);
1409 thread.add_tool(EchoTool);
1410 thread.add_tool(DelayTool);
1411 thread.add_tool(WordListTool);
1412 thread.add_tool(ToolRequiringPermission);
1413 thread.add_tool(InfiniteTool);
1414 });
1415
1416 // Set up multiple context servers with some overlapping tool names
1417 let _server1_calls = setup_context_server(
1418 "xxx",
1419 vec![
1420 context_server::types::Tool {
1421 name: "echo".into(), // Conflicts with native EchoTool
1422 description: None,
1423 input_schema: serde_json::to_value(EchoTool::input_schema(
1424 LanguageModelToolSchemaFormat::JsonSchema,
1425 ))
1426 .unwrap(),
1427 output_schema: None,
1428 annotations: None,
1429 },
1430 context_server::types::Tool {
1431 name: "unique_tool_1".into(),
1432 description: None,
1433 input_schema: json!({"type": "object", "properties": {}}),
1434 output_schema: None,
1435 annotations: None,
1436 },
1437 ],
1438 &context_server_store,
1439 cx,
1440 );
1441
1442 let _server2_calls = setup_context_server(
1443 "yyy",
1444 vec![
1445 context_server::types::Tool {
1446 name: "echo".into(), // Also conflicts with native EchoTool
1447 description: None,
1448 input_schema: serde_json::to_value(EchoTool::input_schema(
1449 LanguageModelToolSchemaFormat::JsonSchema,
1450 ))
1451 .unwrap(),
1452 output_schema: None,
1453 annotations: None,
1454 },
1455 context_server::types::Tool {
1456 name: "unique_tool_2".into(),
1457 description: None,
1458 input_schema: json!({"type": "object", "properties": {}}),
1459 output_schema: None,
1460 annotations: None,
1461 },
1462 context_server::types::Tool {
1463 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1464 description: None,
1465 input_schema: json!({"type": "object", "properties": {}}),
1466 output_schema: None,
1467 annotations: None,
1468 },
1469 context_server::types::Tool {
1470 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1471 description: None,
1472 input_schema: json!({"type": "object", "properties": {}}),
1473 output_schema: None,
1474 annotations: None,
1475 },
1476 ],
1477 &context_server_store,
1478 cx,
1479 );
1480 let _server3_calls = setup_context_server(
1481 "zzz",
1482 vec![
1483 context_server::types::Tool {
1484 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1485 description: None,
1486 input_schema: json!({"type": "object", "properties": {}}),
1487 output_schema: None,
1488 annotations: None,
1489 },
1490 context_server::types::Tool {
1491 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1492 description: None,
1493 input_schema: json!({"type": "object", "properties": {}}),
1494 output_schema: None,
1495 annotations: None,
1496 },
1497 context_server::types::Tool {
1498 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1499 description: None,
1500 input_schema: json!({"type": "object", "properties": {}}),
1501 output_schema: None,
1502 annotations: None,
1503 },
1504 ],
1505 &context_server_store,
1506 cx,
1507 );
1508
1509 thread
1510 .update(cx, |thread, cx| {
1511 thread.send(UserMessageId::new(), ["Go"], cx)
1512 })
1513 .unwrap();
1514 cx.run_until_parked();
1515 let completion = fake_model.pending_completions().pop().unwrap();
1516 assert_eq!(
1517 tool_names_for_completion(&completion),
1518 vec![
1519 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1520 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1521 "delay",
1522 "echo",
1523 "infinite",
1524 "tool_requiring_permission",
1525 "unique_tool_1",
1526 "unique_tool_2",
1527 "word_list",
1528 "xxx_echo",
1529 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1530 "yyy_echo",
1531 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1532 ]
1533 );
1534}
1535
// End-to-end test against a real model (ignored unless the `e2e` feature is
// enabled). It cancels a send while a tool that never finishes is still
// running, checks that the event stream ends with `StopReason::Cancelled`, and
// checks that the thread can still handle a new message afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls are expected in exactly this order; `remove(0)` pops the next
    // expected title. NOTE(review): an unexpected extra ToolCall would panic
    // inside `Vec::remove` rather than fail with a descriptive assertion.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    // Id of the echo tool call, used to recognise its completion update below.
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop listening once both tools have been called and echo has finished;
        // the infinite tool is still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1621
// Cancelling the thread while a terminal tool is running must kill the terminal,
// end the turn with `StopReason::Cancelled`, and still record a tool result that
// includes the output produced so far. The thread must also stay usable afterwards.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal that never exits on its own; the test keeps a handle to
    // inspect whether it was killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    // The cancel task is detached; `collect_events_until_stop` below drives the
    // executor until cancellation finishes.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // Tool results are keyed by the id of the tool use that produced them.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1718
1719#[gpui::test]
1720async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1721 // This test verifies that tools which properly handle cancellation via
1722 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1723 // to cancellation and report that they were cancelled.
1724 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1725 always_allow_tools(cx);
1726 let fake_model = model.as_fake();
1727
1728 let (tool, was_cancelled) = CancellationAwareTool::new();
1729
1730 let mut events = thread
1731 .update(cx, |thread, cx| {
1732 thread.add_tool(tool);
1733 thread.send(
1734 UserMessageId::new(),
1735 ["call the cancellation aware tool"],
1736 cx,
1737 )
1738 })
1739 .unwrap();
1740
1741 cx.run_until_parked();
1742
1743 // Simulate the model calling the cancellation-aware tool
1744 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1745 LanguageModelToolUse {
1746 id: "cancellation_aware_1".into(),
1747 name: "cancellation_aware".into(),
1748 raw_input: r#"{}"#.into(),
1749 input: json!({}),
1750 is_input_complete: true,
1751 thought_signature: None,
1752 },
1753 ));
1754 fake_model.end_last_completion_stream();
1755
1756 cx.run_until_parked();
1757
1758 // Wait for the tool call to be reported
1759 let mut tool_started = false;
1760 let deadline = cx.executor().num_cpus() * 100;
1761 for _ in 0..deadline {
1762 cx.run_until_parked();
1763
1764 while let Some(Some(event)) = events.next().now_or_never() {
1765 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1766 if tool_call.title == "Cancellation Aware Tool" {
1767 tool_started = true;
1768 break;
1769 }
1770 }
1771 }
1772
1773 if tool_started {
1774 break;
1775 }
1776
1777 cx.background_executor
1778 .timer(Duration::from_millis(10))
1779 .await;
1780 }
1781 assert!(tool_started, "expected cancellation aware tool to start");
1782
1783 // Cancel the thread and wait for it to complete
1784 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1785
1786 // The cancel task should complete promptly because the tool handles cancellation
1787 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1788 futures::select! {
1789 _ = cancel_task.fuse() => {}
1790 _ = timeout.fuse() => {
1791 panic!("cancel task timed out - tool did not respond to cancellation");
1792 }
1793 }
1794
1795 // Verify the tool detected cancellation via its flag
1796 assert!(
1797 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1798 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1799 );
1800
1801 // Collect remaining events
1802 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1803
1804 // Verify we got a cancellation stop event
1805 assert_eq!(
1806 stop_events(remaining_events),
1807 vec![acp::StopReason::Cancelled],
1808 );
1809
1810 // Verify we can send a new message after cancellation
1811 verify_thread_recovery(&thread, &fake_model, cx).await;
1812}
1813
1814/// Helper to verify thread can recover after cancellation by sending a simple message.
1815async fn verify_thread_recovery(
1816 thread: &Entity<Thread>,
1817 fake_model: &FakeLanguageModel,
1818 cx: &mut TestAppContext,
1819) {
1820 let events = thread
1821 .update(cx, |thread, cx| {
1822 thread.send(
1823 UserMessageId::new(),
1824 ["Testing: reply with 'Hello' then stop."],
1825 cx,
1826 )
1827 })
1828 .unwrap();
1829 cx.run_until_parked();
1830 fake_model.send_last_completion_stream_text_chunk("Hello");
1831 fake_model
1832 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1833 fake_model.end_last_completion_stream();
1834
1835 let events = events.collect::<Vec<_>>().await;
1836 thread.update(cx, |thread, _cx| {
1837 let message = thread.last_message().unwrap();
1838 let agent_message = message.as_agent_message().unwrap();
1839 assert_eq!(
1840 agent_message.content,
1841 vec![AgentMessageContent::Text("Hello".to_string())]
1842 );
1843 });
1844 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1845}
1846
1847/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1848async fn wait_for_terminal_tool_started(
1849 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1850 cx: &mut TestAppContext,
1851) {
1852 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1853 for _ in 0..deadline {
1854 cx.run_until_parked();
1855
1856 while let Some(Some(event)) = events.next().now_or_never() {
1857 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1858 update,
1859 ))) = &event
1860 {
1861 if update.fields.content.as_ref().is_some_and(|content| {
1862 content
1863 .iter()
1864 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1865 }) {
1866 return;
1867 }
1868 }
1869 }
1870
1871 cx.background_executor
1872 .timer(Duration::from_millis(10))
1873 .await;
1874 }
1875 panic!("terminal tool did not start within the expected time");
1876}
1877
1878/// Collects events until a Stop event is received, driving the executor to completion.
1879async fn collect_events_until_stop(
1880 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1881 cx: &mut TestAppContext,
1882) -> Vec<Result<ThreadEvent>> {
1883 let mut collected = Vec::new();
1884 let deadline = cx.executor().num_cpus() * 200;
1885
1886 for _ in 0..deadline {
1887 cx.executor().advance_clock(Duration::from_millis(10));
1888 cx.run_until_parked();
1889
1890 while let Some(Some(event)) = events.next().now_or_never() {
1891 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1892 collected.push(event);
1893 if is_stop {
1894 return collected;
1895 }
1896 }
1897 }
1898 panic!(
1899 "did not receive Stop event within the expected time; collected {} events",
1900 collected.len()
1901 );
1902}
1903
// Truncating the thread at the only user message while a terminal tool is
// running must kill the terminal, leave the thread empty, and keep the thread
// usable for new messages.
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal that never exits on its own; the test keeps a handle to
    // inspect whether it was killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the message id so the thread can be truncated back to it later.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    // (the collected events themselves are not inspected here).
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1970
1971#[gpui::test]
1972async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1973 // Tests that cancellation properly kills all running terminal tools when multiple are active.
1974 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1975 always_allow_tools(cx);
1976 let fake_model = model.as_fake();
1977
1978 let environment = Rc::new(MultiTerminalEnvironment::new());
1979
1980 let mut events = thread
1981 .update(cx, |thread, cx| {
1982 thread.add_tool(crate::TerminalTool::new(
1983 thread.project().clone(),
1984 environment.clone(),
1985 ));
1986 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1987 })
1988 .unwrap();
1989
1990 cx.run_until_parked();
1991
1992 // Simulate the model calling two terminal tools
1993 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1994 LanguageModelToolUse {
1995 id: "terminal_tool_1".into(),
1996 name: "terminal".into(),
1997 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1998 input: json!({"command": "sleep 1000", "cd": "."}),
1999 is_input_complete: true,
2000 thought_signature: None,
2001 },
2002 ));
2003 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2004 LanguageModelToolUse {
2005 id: "terminal_tool_2".into(),
2006 name: "terminal".into(),
2007 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2008 input: json!({"command": "sleep 2000", "cd": "."}),
2009 is_input_complete: true,
2010 thought_signature: None,
2011 },
2012 ));
2013 fake_model.end_last_completion_stream();
2014
2015 // Wait for both terminal tools to start by counting terminal content updates
2016 let mut terminals_started = 0;
2017 let deadline = cx.executor().num_cpus() * 100;
2018 for _ in 0..deadline {
2019 cx.run_until_parked();
2020
2021 while let Some(Some(event)) = events.next().now_or_never() {
2022 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2023 update,
2024 ))) = &event
2025 {
2026 if update.fields.content.as_ref().is_some_and(|content| {
2027 content
2028 .iter()
2029 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2030 }) {
2031 terminals_started += 1;
2032 if terminals_started >= 2 {
2033 break;
2034 }
2035 }
2036 }
2037 }
2038 if terminals_started >= 2 {
2039 break;
2040 }
2041
2042 cx.background_executor
2043 .timer(Duration::from_millis(10))
2044 .await;
2045 }
2046 assert!(
2047 terminals_started >= 2,
2048 "expected 2 terminal tools to start, got {terminals_started}"
2049 );
2050
2051 // Cancel the thread while both terminals are running
2052 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2053
2054 // Collect remaining events
2055 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2056
2057 // Verify both terminal handles were killed
2058 let handles = environment.handles();
2059 assert_eq!(
2060 handles.len(),
2061 2,
2062 "expected 2 terminal handles to be created"
2063 );
2064 assert!(
2065 handles[0].was_killed(),
2066 "expected first terminal handle to be killed on cancellation"
2067 );
2068 assert!(
2069 handles[1].was_killed(),
2070 "expected second terminal handle to be killed on cancellation"
2071 );
2072
2073 // Verify we got a cancellation stop event
2074 assert_eq!(
2075 stop_events(remaining_events),
2076 vec![acp::StopReason::Cancelled],
2077 );
2078}
2079
2080#[gpui::test]
2081async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2082 // Tests that clicking the stop button on the terminal card (as opposed to the main
2083 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2084 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2085 always_allow_tools(cx);
2086 let fake_model = model.as_fake();
2087
2088 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2089 let environment = Rc::new(FakeThreadEnvironment {
2090 handle: handle.clone(),
2091 });
2092
2093 let mut events = thread
2094 .update(cx, |thread, cx| {
2095 thread.add_tool(crate::TerminalTool::new(
2096 thread.project().clone(),
2097 environment,
2098 ));
2099 thread.send(UserMessageId::new(), ["run a command"], cx)
2100 })
2101 .unwrap();
2102
2103 cx.run_until_parked();
2104
2105 // Simulate the model calling the terminal tool
2106 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2107 LanguageModelToolUse {
2108 id: "terminal_tool_1".into(),
2109 name: "terminal".into(),
2110 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2111 input: json!({"command": "sleep 1000", "cd": "."}),
2112 is_input_complete: true,
2113 thought_signature: None,
2114 },
2115 ));
2116 fake_model.end_last_completion_stream();
2117
2118 // Wait for the terminal tool to start running
2119 wait_for_terminal_tool_started(&mut events, cx).await;
2120
2121 // Simulate user clicking stop on the terminal card itself.
2122 // This sets the flag and signals exit (simulating what the real UI would do).
2123 handle.set_stopped_by_user(true);
2124 handle.killed.store(true, Ordering::SeqCst);
2125 handle.signal_exit();
2126
2127 // Wait for the tool to complete
2128 cx.run_until_parked();
2129
2130 // The thread continues after tool completion - simulate the model ending its turn
2131 fake_model
2132 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2133 fake_model.end_last_completion_stream();
2134
2135 // Collect remaining events
2136 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2137
2138 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2139 assert_eq!(
2140 stop_events(remaining_events),
2141 vec![acp::StopReason::EndTurn],
2142 );
2143
2144 // Verify the tool result indicates user stopped
2145 thread.update(cx, |thread, _cx| {
2146 let message = thread.last_message().unwrap();
2147 let agent_message = message.as_agent_message().unwrap();
2148
2149 let tool_use = agent_message
2150 .content
2151 .iter()
2152 .find_map(|content| match content {
2153 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2154 _ => None,
2155 })
2156 .expect("expected tool use in agent message");
2157
2158 let tool_result = agent_message
2159 .tool_results
2160 .get(&tool_use.id)
2161 .expect("expected tool result");
2162
2163 let result_text = match &tool_result.content {
2164 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2165 _ => panic!("expected text content in tool result"),
2166 };
2167
2168 assert!(
2169 result_text.contains("The user stopped this command"),
2170 "expected tool result to indicate user stopped, got: {result_text}"
2171 );
2172 });
2173}
2174
2175#[gpui::test]
2176async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2177 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2178 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2179 always_allow_tools(cx);
2180 let fake_model = model.as_fake();
2181
2182 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2183 let environment = Rc::new(FakeThreadEnvironment {
2184 handle: handle.clone(),
2185 });
2186
2187 let mut events = thread
2188 .update(cx, |thread, cx| {
2189 thread.add_tool(crate::TerminalTool::new(
2190 thread.project().clone(),
2191 environment,
2192 ));
2193 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2194 })
2195 .unwrap();
2196
2197 cx.run_until_parked();
2198
2199 // Simulate the model calling the terminal tool with a short timeout
2200 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2201 LanguageModelToolUse {
2202 id: "terminal_tool_1".into(),
2203 name: "terminal".into(),
2204 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2205 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2206 is_input_complete: true,
2207 thought_signature: None,
2208 },
2209 ));
2210 fake_model.end_last_completion_stream();
2211
2212 // Wait for the terminal tool to start running
2213 wait_for_terminal_tool_started(&mut events, cx).await;
2214
2215 // Advance clock past the timeout
2216 cx.executor().advance_clock(Duration::from_millis(200));
2217 cx.run_until_parked();
2218
2219 // The thread continues after tool completion - simulate the model ending its turn
2220 fake_model
2221 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2222 fake_model.end_last_completion_stream();
2223
2224 // Collect remaining events
2225 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2226
2227 // Verify the terminal was killed due to timeout
2228 assert!(
2229 handle.was_killed(),
2230 "expected terminal handle to be killed on timeout"
2231 );
2232
2233 // Verify we got an EndTurn (the tool completed, just with timeout)
2234 assert_eq!(
2235 stop_events(remaining_events),
2236 vec![acp::StopReason::EndTurn],
2237 );
2238
2239 // Verify the tool result indicates timeout, not user stopped
2240 thread.update(cx, |thread, _cx| {
2241 let message = thread.last_message().unwrap();
2242 let agent_message = message.as_agent_message().unwrap();
2243
2244 let tool_use = agent_message
2245 .content
2246 .iter()
2247 .find_map(|content| match content {
2248 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2249 _ => None,
2250 })
2251 .expect("expected tool use in agent message");
2252
2253 let tool_result = agent_message
2254 .tool_results
2255 .get(&tool_use.id)
2256 .expect("expected tool result");
2257
2258 let result_text = match &tool_result.content {
2259 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2260 _ => panic!("expected text content in tool result"),
2261 };
2262
2263 assert!(
2264 result_text.contains("timed out"),
2265 "expected tool result to indicate timeout, got: {result_text}"
2266 );
2267 assert!(
2268 !result_text.contains("The user stopped"),
2269 "tool result should not mention user stopped when it timed out, got: {result_text}"
2270 );
2271 });
2272}
2273
/// Issuing a second `send` while the first turn is still streaming cancels the
/// first turn: its event stream ends with `Cancelled`, while the second turn
/// runs to completion and ends with `EndTurn`.
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First turn: the model starts streaming but its completion is never ended.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Second turn is sent while the first is still in progress; the fake
    // model's "last completion" is now the one for this turn.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2304
/// A `send` issued after the previous turn has fully completed must not turn
/// that earlier turn into a cancellation: both event streams end with `EndTurn`.
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First turn runs to completion and its events are drained before the
    // second message is sent.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    // Second turn, also run to completion.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2337
2338#[gpui::test]
2339async fn test_refusal(cx: &mut TestAppContext) {
2340 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2341 let fake_model = model.as_fake();
2342
2343 let events = thread
2344 .update(cx, |thread, cx| {
2345 thread.send(UserMessageId::new(), ["Hello"], cx)
2346 })
2347 .unwrap();
2348 cx.run_until_parked();
2349 thread.read_with(cx, |thread, _| {
2350 assert_eq!(
2351 thread.to_markdown(),
2352 indoc! {"
2353 ## User
2354
2355 Hello
2356 "}
2357 );
2358 });
2359
2360 fake_model.send_last_completion_stream_text_chunk("Hey!");
2361 cx.run_until_parked();
2362 thread.read_with(cx, |thread, _| {
2363 assert_eq!(
2364 thread.to_markdown(),
2365 indoc! {"
2366 ## User
2367
2368 Hello
2369
2370 ## Assistant
2371
2372 Hey!
2373 "}
2374 );
2375 });
2376
2377 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2378 fake_model
2379 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2380 let events = events.collect::<Vec<_>>().await;
2381 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2382 thread.read_with(cx, |thread, _| {
2383 assert_eq!(thread.to_markdown(), "");
2384 });
2385}
2386
2387#[gpui::test]
2388async fn test_truncate_first_message(cx: &mut TestAppContext) {
2389 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2390 let fake_model = model.as_fake();
2391
2392 let message_id = UserMessageId::new();
2393 thread
2394 .update(cx, |thread, cx| {
2395 thread.send(message_id.clone(), ["Hello"], cx)
2396 })
2397 .unwrap();
2398 cx.run_until_parked();
2399 thread.read_with(cx, |thread, _| {
2400 assert_eq!(
2401 thread.to_markdown(),
2402 indoc! {"
2403 ## User
2404
2405 Hello
2406 "}
2407 );
2408 assert_eq!(thread.latest_token_usage(), None);
2409 });
2410
2411 fake_model.send_last_completion_stream_text_chunk("Hey!");
2412 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2413 language_model::TokenUsage {
2414 input_tokens: 32_000,
2415 output_tokens: 16_000,
2416 cache_creation_input_tokens: 0,
2417 cache_read_input_tokens: 0,
2418 },
2419 ));
2420 cx.run_until_parked();
2421 thread.read_with(cx, |thread, _| {
2422 assert_eq!(
2423 thread.to_markdown(),
2424 indoc! {"
2425 ## User
2426
2427 Hello
2428
2429 ## Assistant
2430
2431 Hey!
2432 "}
2433 );
2434 assert_eq!(
2435 thread.latest_token_usage(),
2436 Some(acp_thread::TokenUsage {
2437 used_tokens: 32_000 + 16_000,
2438 max_tokens: 1_000_000,
2439 input_tokens: 32_000,
2440 output_tokens: 16_000,
2441 })
2442 );
2443 });
2444
2445 thread
2446 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2447 .unwrap();
2448 cx.run_until_parked();
2449 thread.read_with(cx, |thread, _| {
2450 assert_eq!(thread.to_markdown(), "");
2451 assert_eq!(thread.latest_token_usage(), None);
2452 });
2453
2454 // Ensure we can still send a new message after truncation.
2455 thread
2456 .update(cx, |thread, cx| {
2457 thread.send(UserMessageId::new(), ["Hi"], cx)
2458 })
2459 .unwrap();
2460 thread.update(cx, |thread, _cx| {
2461 assert_eq!(
2462 thread.to_markdown(),
2463 indoc! {"
2464 ## User
2465
2466 Hi
2467 "}
2468 );
2469 });
2470 cx.run_until_parked();
2471 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2472 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2473 language_model::TokenUsage {
2474 input_tokens: 40_000,
2475 output_tokens: 20_000,
2476 cache_creation_input_tokens: 0,
2477 cache_read_input_tokens: 0,
2478 },
2479 ));
2480 cx.run_until_parked();
2481 thread.read_with(cx, |thread, _| {
2482 assert_eq!(
2483 thread.to_markdown(),
2484 indoc! {"
2485 ## User
2486
2487 Hi
2488
2489 ## Assistant
2490
2491 Ahoy!
2492 "}
2493 );
2494
2495 assert_eq!(
2496 thread.latest_token_usage(),
2497 Some(acp_thread::TokenUsage {
2498 used_tokens: 40_000 + 20_000,
2499 max_tokens: 1_000_000,
2500 input_tokens: 40_000,
2501 output_tokens: 20_000,
2502 })
2503 );
2504 });
2505}
2506
2507#[gpui::test]
2508async fn test_truncate_second_message(cx: &mut TestAppContext) {
2509 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2510 let fake_model = model.as_fake();
2511
2512 thread
2513 .update(cx, |thread, cx| {
2514 thread.send(UserMessageId::new(), ["Message 1"], cx)
2515 })
2516 .unwrap();
2517 cx.run_until_parked();
2518 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2520 language_model::TokenUsage {
2521 input_tokens: 32_000,
2522 output_tokens: 16_000,
2523 cache_creation_input_tokens: 0,
2524 cache_read_input_tokens: 0,
2525 },
2526 ));
2527 fake_model.end_last_completion_stream();
2528 cx.run_until_parked();
2529
2530 let assert_first_message_state = |cx: &mut TestAppContext| {
2531 thread.clone().read_with(cx, |thread, _| {
2532 assert_eq!(
2533 thread.to_markdown(),
2534 indoc! {"
2535 ## User
2536
2537 Message 1
2538
2539 ## Assistant
2540
2541 Message 1 response
2542 "}
2543 );
2544
2545 assert_eq!(
2546 thread.latest_token_usage(),
2547 Some(acp_thread::TokenUsage {
2548 used_tokens: 32_000 + 16_000,
2549 max_tokens: 1_000_000,
2550 input_tokens: 32_000,
2551 output_tokens: 16_000,
2552 })
2553 );
2554 });
2555 };
2556
2557 assert_first_message_state(cx);
2558
2559 let second_message_id = UserMessageId::new();
2560 thread
2561 .update(cx, |thread, cx| {
2562 thread.send(second_message_id.clone(), ["Message 2"], cx)
2563 })
2564 .unwrap();
2565 cx.run_until_parked();
2566
2567 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2568 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2569 language_model::TokenUsage {
2570 input_tokens: 40_000,
2571 output_tokens: 20_000,
2572 cache_creation_input_tokens: 0,
2573 cache_read_input_tokens: 0,
2574 },
2575 ));
2576 fake_model.end_last_completion_stream();
2577 cx.run_until_parked();
2578
2579 thread.read_with(cx, |thread, _| {
2580 assert_eq!(
2581 thread.to_markdown(),
2582 indoc! {"
2583 ## User
2584
2585 Message 1
2586
2587 ## Assistant
2588
2589 Message 1 response
2590
2591 ## User
2592
2593 Message 2
2594
2595 ## Assistant
2596
2597 Message 2 response
2598 "}
2599 );
2600
2601 assert_eq!(
2602 thread.latest_token_usage(),
2603 Some(acp_thread::TokenUsage {
2604 used_tokens: 40_000 + 20_000,
2605 max_tokens: 1_000_000,
2606 input_tokens: 40_000,
2607 output_tokens: 20_000,
2608 })
2609 );
2610 });
2611
2612 thread
2613 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2614 .unwrap();
2615 cx.run_until_parked();
2616
2617 assert_first_message_state(cx);
2618}
2619
/// After the first completed turn the summarization model is asked for a title,
/// and only the first line of its response becomes the thread title. A later
/// turn must not request another title.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // A separate fake model installed as the thread's summarization model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The title is still "New Thread" while the summary request is pending.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The streamed text spans two lines ("Hello world" / "Goodnight Moon");
    // the assertion below expects only the first line to be used.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new summarization request was made for the second turn.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2665
/// Building a completion request while a tool call is still pending must omit
/// that tool use. Here the pending call is `ToolRequiringPermission`, whose
/// permission is never granted in this test. Completed tool uses and their
/// results are still included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model issues two tool calls in one response: one that waits on a
    // permission prompt and one (echo) that can run immediately.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // The first request message is skipped (presumably the system prompt —
    // not asserted here); the rest must contain only the echo tool use and
    // its result.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2745
2746#[gpui::test]
2747async fn test_agent_connection(cx: &mut TestAppContext) {
2748 cx.update(settings::init);
2749 let templates = Templates::new();
2750
2751 // Initialize language model system with test provider
2752 cx.update(|cx| {
2753 gpui_tokio::init(cx);
2754
2755 let http_client = FakeHttpClient::with_404_response();
2756 let clock = Arc::new(clock::FakeSystemClock::new());
2757 let client = Client::new(clock, http_client, cx);
2758 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2759 language_model::init(client.clone(), cx);
2760 language_models::init(user_store, client.clone(), cx);
2761 LanguageModelRegistry::test(cx);
2762 });
2763 cx.executor().forbid_parking();
2764
2765 // Create a project for new_thread
2766 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2767 fake_fs.insert_tree(path!("/test"), json!({})).await;
2768 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2769 let cwd = Path::new("/test");
2770 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2771
2772 // Create agent and connection
2773 let agent = NativeAgent::new(
2774 project.clone(),
2775 thread_store,
2776 templates.clone(),
2777 None,
2778 fake_fs.clone(),
2779 &mut cx.to_async(),
2780 )
2781 .await
2782 .unwrap();
2783 let connection = NativeAgentConnection(agent.clone());
2784
2785 // Create a thread using new_thread
2786 let connection_rc = Rc::new(connection.clone());
2787 let acp_thread = cx
2788 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2789 .await
2790 .expect("new_thread should succeed");
2791
2792 // Get the session_id from the AcpThread
2793 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2794
2795 // Test model_selector returns Some
2796 let selector_opt = connection.model_selector(&session_id);
2797 assert!(
2798 selector_opt.is_some(),
2799 "agent should always support ModelSelector"
2800 );
2801 let selector = selector_opt.unwrap();
2802
2803 // Test list_models
2804 let listed_models = cx
2805 .update(|cx| selector.list_models(cx))
2806 .await
2807 .expect("list_models should succeed");
2808 let AgentModelList::Grouped(listed_models) = listed_models else {
2809 panic!("Unexpected model list type");
2810 };
2811 assert!(!listed_models.is_empty(), "should have at least one model");
2812 assert_eq!(
2813 listed_models[&AgentModelGroupName("Fake".into())][0]
2814 .id
2815 .0
2816 .as_ref(),
2817 "fake/fake"
2818 );
2819
2820 // Test selected_model returns the default
2821 let model = cx
2822 .update(|cx| selector.selected_model(cx))
2823 .await
2824 .expect("selected_model should succeed");
2825 let model = cx
2826 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2827 .unwrap();
2828 let model = model.as_fake();
2829 assert_eq!(model.id().0, "fake", "should return default model");
2830
2831 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2832 cx.run_until_parked();
2833 model.send_last_completion_stream_text_chunk("def");
2834 cx.run_until_parked();
2835 acp_thread.read_with(cx, |thread, cx| {
2836 assert_eq!(
2837 thread.to_markdown(cx),
2838 indoc! {"
2839 ## User
2840
2841 abc
2842
2843 ## Assistant
2844
2845 def
2846
2847 "}
2848 )
2849 });
2850
2851 // Test cancel
2852 cx.update(|cx| connection.cancel(&session_id, cx));
2853 request.await.expect("prompt should fail gracefully");
2854
2855 // Ensure that dropping the ACP thread causes the native thread to be
2856 // dropped as well.
2857 cx.update(|_| drop(acp_thread));
2858 let result = cx
2859 .update(|cx| {
2860 connection.prompt(
2861 Some(acp_thread::UserMessageId::new()),
2862 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2863 cx,
2864 )
2865 })
2866 .await;
2867 assert_eq!(
2868 result.as_ref().unwrap_err().to_string(),
2869 "Session not found",
2870 "unexpected result: {:?}",
2871 result
2872 );
2873}
2874
2875#[gpui::test]
2876async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2877 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2878 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2879 let fake_model = model.as_fake();
2880
2881 let mut events = thread
2882 .update(cx, |thread, cx| {
2883 thread.send(UserMessageId::new(), ["Think"], cx)
2884 })
2885 .unwrap();
2886 cx.run_until_parked();
2887
2888 // Simulate streaming partial input.
2889 let input = json!({});
2890 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2891 LanguageModelToolUse {
2892 id: "1".into(),
2893 name: ThinkingTool::name().into(),
2894 raw_input: input.to_string(),
2895 input,
2896 is_input_complete: false,
2897 thought_signature: None,
2898 },
2899 ));
2900
2901 // Input streaming completed
2902 let input = json!({ "content": "Thinking hard!" });
2903 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2904 LanguageModelToolUse {
2905 id: "1".into(),
2906 name: "thinking".into(),
2907 raw_input: input.to_string(),
2908 input,
2909 is_input_complete: true,
2910 thought_signature: None,
2911 },
2912 ));
2913 fake_model.end_last_completion_stream();
2914 cx.run_until_parked();
2915
2916 let tool_call = expect_tool_call(&mut events).await;
2917 assert_eq!(
2918 tool_call,
2919 acp::ToolCall::new("1", "Thinking")
2920 .kind(acp::ToolKind::Think)
2921 .raw_input(json!({}))
2922 .meta(acp::Meta::from_iter([(
2923 "tool_name".into(),
2924 "thinking".into()
2925 )]))
2926 );
2927 let update = expect_tool_call_update_fields(&mut events).await;
2928 assert_eq!(
2929 update,
2930 acp::ToolCallUpdate::new(
2931 "1",
2932 acp::ToolCallUpdateFields::new()
2933 .title("Thinking")
2934 .kind(acp::ToolKind::Think)
2935 .raw_input(json!({ "content": "Thinking hard!"}))
2936 )
2937 );
2938 let update = expect_tool_call_update_fields(&mut events).await;
2939 assert_eq!(
2940 update,
2941 acp::ToolCallUpdate::new(
2942 "1",
2943 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2944 )
2945 );
2946 let update = expect_tool_call_update_fields(&mut events).await;
2947 assert_eq!(
2948 update,
2949 acp::ToolCallUpdate::new(
2950 "1",
2951 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2952 )
2953 );
2954 let update = expect_tool_call_update_fields(&mut events).await;
2955 assert_eq!(
2956 update,
2957 acp::ToolCallUpdate::new(
2958 "1",
2959 acp::ToolCallUpdateFields::new()
2960 .status(acp::ToolCallStatus::Completed)
2961 .raw_output("Finished thinking.")
2962 )
2963 );
2964}
2965
/// A completion that succeeds on the first attempt produces no `Retry` events,
/// and the thread contains just the single exchange.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, recording any retries. The loop also
    // ends if the stream closes or yields an `Err`.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3008
/// A retryable `ServerOverloaded` error in the middle of a stream produces
/// exactly one retry (attempt 1). The retried completion's text appears in a
/// new assistant section after a `[resume]` marker.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Partial text, then a retryable error carrying a 3s retry-after hint.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the clock by the `retry_after` duration so the retry can start.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Drive the retried completion to a successful end.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events until the turn stops, recording retries. The loop also
    // ends if the stream closes or yields an `Err`.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3072
/// If a stream fails with a retryable error after emitting a complete tool use,
/// the tool still runs. The retried request must carry both the tool use and
/// its result, and the thread's final message is only the retried response.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model emits a complete echo tool use, then the stream errors.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the clock by `retry_after`; the retried request is then pending
    // on the fake model.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // The first request message is skipped (presumably the system prompt). Of
    // the rest, only the final message is expected to be marked `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3152
3153#[gpui::test]
3154async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3155 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3156 let fake_model = model.as_fake();
3157
3158 let mut events = thread
3159 .update(cx, |thread, cx| {
3160 thread.send(UserMessageId::new(), ["Hello!"], cx)
3161 })
3162 .unwrap();
3163 cx.run_until_parked();
3164
3165 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3166 fake_model.send_last_completion_stream_error(
3167 LanguageModelCompletionError::ServerOverloaded {
3168 provider: LanguageModelProviderName::new("Anthropic"),
3169 retry_after: Some(Duration::from_secs(3)),
3170 },
3171 );
3172 fake_model.end_last_completion_stream();
3173 cx.executor().advance_clock(Duration::from_secs(3));
3174 cx.run_until_parked();
3175 }
3176
3177 let mut errors = Vec::new();
3178 let mut retry_events = Vec::new();
3179 while let Some(event) = events.next().await {
3180 match event {
3181 Ok(ThreadEvent::Retry(retry_status)) => {
3182 retry_events.push(retry_status);
3183 }
3184 Ok(ThreadEvent::Stop(..)) => break,
3185 Err(error) => errors.push(error),
3186 _ => {}
3187 }
3188 }
3189
3190 assert_eq!(
3191 retry_events.len(),
3192 crate::thread::MAX_RETRY_ATTEMPTS as usize
3193 );
3194 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3195 assert_eq!(retry_events[i].attempt, i + 1);
3196 }
3197 assert_eq!(errors.len(), 1);
3198 let error = errors[0]
3199 .downcast_ref::<LanguageModelCompletionError>()
3200 .unwrap();
3201 assert!(matches!(
3202 error,
3203 LanguageModelCompletionError::ServerOverloaded { .. }
3204 ));
3205}
3206
3207/// Filters out the stop events for asserting against in tests
3208fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3209 result_events
3210 .into_iter()
3211 .filter_map(|event| match event.unwrap() {
3212 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3213 _ => None,
3214 })
3215 .collect()
3216}
3217
/// Everything a test gets back from [`setup`].
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context handed to the thread.
    project_context: Entity<ProjectContext>,
    /// The test project's context server store (e.g. for `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    /// Fake file system holding the test project and the user settings file.
    fs: Arc<FakeFs>,
}
3225
/// Which language model [`setup`] should give the thread under test.
enum TestModel {
    /// Claude Sonnet 4, looked up in the language model registry; `setup`
    /// authenticates with its provider before returning.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose completion streams the test drives.
    Fake,
}
3230
3231impl TestModel {
3232 fn id(&self) -> LanguageModelId {
3233 match self {
3234 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3235 TestModel::Fake => unreachable!(),
3236 }
3237 }
3238}
3239
/// Builds a [`Thread`] over a fake file system and a test project rooted at
/// `/test`, and returns it together with the pieces tests commonly need.
///
/// A user settings file is written that makes `test-profile` the default agent
/// profile, with the test tools and `terminal` enabled. That file is watched so
/// later edits are applied to the global settings.
///
/// For [`TestModel::Sonnet4`] this also installs a real HTTP client and the
/// production language model providers, then authenticates with the model's
/// provider. NOTE(review): this presumably needs network access and credentials.
///
/// # Panics
///
/// Panics if the settings directory cannot be created, if the requested real
/// model is not available in the registry, or if authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed so the executor may block on work that
    // completes outside it (e.g. real HTTP requests for `Sonnet4`).
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Default agent profile enabling every test tool plus the terminal tool.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real models need an HTTP client, a production client and the
            // language model providers registered globally.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake one is ready immediately, a real one becomes
    // available once its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3343
// Runs once when the test binary is loaded (via `ctor`) and installs
// `env_logger`, but only if the `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}
3351
3352fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3353 let fs = fs.clone();
3354 cx.spawn({
3355 async move |cx| {
3356 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3357 cx.background_executor(),
3358 fs,
3359 paths::settings_file().clone(),
3360 );
3361 let _watcher_task = watcher_task;
3362
3363 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3364 cx.update(|cx| {
3365 SettingsStore::update_global(cx, |settings, cx| {
3366 settings.set_user_settings(&new_settings_content, cx)
3367 })
3368 })
3369 .ok();
3370 }
3371 }
3372 })
3373 .detach();
3374}
3375
3376fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3377 completion
3378 .tools
3379 .iter()
3380 .map(|tool| tool.name.clone())
3381 .collect()
3382}
3383
/// Registers a stdio context server named `name` in the project settings and
/// starts it in `context_server_store` over a fake transport that advertises
/// `tools`.
///
/// Every `CallTool` request the fake server receives is forwarded on the
/// returned channel together with a oneshot sender. The request stays pending
/// until the test replies through that sender.
///
/// # Panics
///
/// The fake `CallTool` handler panics if the returned receiver has been dropped,
/// or if the response sender is dropped without sending a reply.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Enable the server in the project settings. NOTE(review): `somebinary`
    // looks like a placeholder; the server is started below with the fake
    // transport rather than by spawning this command.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report the latest protocol version and advertise tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Always answer tool listing with the full `tools` set, in a single page.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Hand each tool call to the test and wait for its reply.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3463
// Checks that `tokens_before_message` returns `None` for the first user message.
// For every later message it returns the input token count reported by the
// request made for the previous message, and earlier answers stay stable as the
// thread grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3570
// Checks that after `truncate` at a user message, `tokens_before_message` no
// longer finds that message, and the first message still reports `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // NOTE(review): only two messages are actually sent below.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3641
// Exercises how the terminal tool applies per-tool `tool_permissions` rules and
// the `always_allow_tool_actions` setting. The four scenarios below run in
// sequence against the same app, and each overrides the global `AgentSettings`.
// A setting that a scenario does not assign therefore keeps whatever value the
// previous scenario left behind.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        // NOTE(review): `always_allow_tool_actions` is not assigned here, so this
        // scenario relies on its default value.
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        let err_msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The first event must already carry terminal content; a confirmation
        // prompt at this point would mean the allow rule was not applied.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, confirm patterns are overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }

    // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, even default_mode: Deny is overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }
}
3850
// Checks that `add_default_tools` registers the `subagent` tool on a thread
// created with `Thread::new` when the `subagents` feature flag is enabled.
#[gpui::test]
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // `add_default_tools` needs a thread environment; a never-exiting fake
    // terminal is enough since no tool is run.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let thread = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.has_registered_tool("subagent"),
            "subagent tool should be present when feature flag is enabled"
        );
    });
}
3891
3892#[gpui::test]
3893async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3894 init_test(cx);
3895
3896 cx.update(|cx| {
3897 cx.update_flags(true, vec!["subagents".to_string()]);
3898 });
3899
3900 let fs = FakeFs::new(cx.executor());
3901 fs.insert_tree(path!("/test"), json!({})).await;
3902 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3903 let project_context = cx.new(|_cx| ProjectContext::default());
3904 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3905 let context_server_registry =
3906 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3907 let model = Arc::new(FakeLanguageModel::default());
3908
3909 let subagent_context = SubagentContext {
3910 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3911 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3912 depth: 1,
3913 summary_prompt: "Summarize".to_string(),
3914 context_low_prompt: "Context low".to_string(),
3915 };
3916
3917 let subagent = cx.new(|cx| {
3918 Thread::new_subagent(
3919 project.clone(),
3920 project_context,
3921 context_server_registry,
3922 Templates::new(),
3923 model.clone(),
3924 subagent_context,
3925 std::collections::BTreeMap::new(),
3926 cx,
3927 )
3928 });
3929
3930 subagent.read_with(cx, |thread, _| {
3931 assert!(thread.is_subagent());
3932 assert_eq!(thread.depth(), 1);
3933 assert!(thread.model().is_some());
3934 });
3935}
3936
// Checks that a subagent created at `MAX_SUBAGENT_DEPTH` does not get the
// `subagent` tool from `add_default_tools`, even with the `subagents` flag on.
#[gpui::test]
async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: MAX_SUBAGENT_DEPTH,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let deep_subagent = cx.new(|cx| {
        let mut thread = Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    deep_subagent.read_with(cx, |thread, _| {
        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
        assert!(
            !thread.has_registered_tool("subagent"),
            "subagent tool should not be present at max depth"
        );
    });
}
3988
3989#[gpui::test]
3990async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3991 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3992 let fake_model = model.as_fake();
3993
3994 cx.update(|cx| {
3995 cx.update_flags(true, vec!["subagents".to_string()]);
3996 });
3997
3998 let subagent_context = SubagentContext {
3999 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4000 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4001 depth: 1,
4002 summary_prompt: "Summarize your work".to_string(),
4003 context_low_prompt: "Context low, wrap up".to_string(),
4004 };
4005
4006 let project = thread.read_with(cx, |t, _| t.project.clone());
4007 let project_context = cx.new(|_cx| ProjectContext::default());
4008 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4009 let context_server_registry =
4010 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4011
4012 let subagent = cx.new(|cx| {
4013 Thread::new_subagent(
4014 project.clone(),
4015 project_context,
4016 context_server_registry,
4017 Templates::new(),
4018 model.clone(),
4019 subagent_context,
4020 std::collections::BTreeMap::new(),
4021 cx,
4022 )
4023 });
4024
4025 let task_prompt = "Find all TODO comments in the codebase";
4026 subagent
4027 .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4028 .unwrap();
4029 cx.run_until_parked();
4030
4031 let pending = fake_model.pending_completions();
4032 assert_eq!(pending.len(), 1, "should have one pending completion");
4033
4034 let messages = &pending[0].messages;
4035 let user_messages: Vec<_> = messages
4036 .iter()
4037 .filter(|m| m.role == language_model::Role::User)
4038 .collect();
4039 assert_eq!(user_messages.len(), 1, "should have one user message");
4040
4041 let content = &user_messages[0].content[0];
4042 assert!(
4043 content.to_str().unwrap().contains("TODO"),
4044 "task prompt should be in user message"
4045 );
4046}
4047
// Checks that after a subagent finishes a turn, `request_final_summary` issues a
// new completion whose last user message contains the configured summary prompt.
#[gpui::test]
async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Please summarize what you found".to_string(),
        context_low_prompt: "Context low, wrap up".to_string(),
    };

    let project = thread.read_with(cx, |t, _| t.project.clone());
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    subagent
        .update(cx, |thread, cx| {
            thread.submit_user_message("Do some work", cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Finish the first turn normally before asking for the summary.
    fake_model.send_last_completion_stream_text_chunk("I did the work");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    subagent
        .update(cx, |thread, cx| thread.request_final_summary(cx))
        .unwrap();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "should have pending completion for summary"
    );

    // "summarize" is a substring of the `summary_prompt` configured above.
    let messages = &pending.last().unwrap().messages;
    let user_messages: Vec<_> = messages
        .iter()
        .filter(|m| m.role == language_model::Role::User)
        .collect();

    let last_user = user_messages.last().unwrap();
    assert!(
        last_user.content[0].to_str().unwrap().contains("summarize"),
        "summary prompt should be sent"
    );
}
4120
4121#[gpui::test]
4122async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4123 init_test(cx);
4124
4125 cx.update(|cx| {
4126 cx.update_flags(true, vec!["subagents".to_string()]);
4127 });
4128
4129 let fs = FakeFs::new(cx.executor());
4130 fs.insert_tree(path!("/test"), json!({})).await;
4131 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4132 let project_context = cx.new(|_cx| ProjectContext::default());
4133 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4134 let context_server_registry =
4135 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4136 let model = Arc::new(FakeLanguageModel::default());
4137
4138 let subagent_context = SubagentContext {
4139 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4140 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4141 depth: 1,
4142 summary_prompt: "Summarize".to_string(),
4143 context_low_prompt: "Context low".to_string(),
4144 };
4145
4146 let subagent = cx.new(|cx| {
4147 let mut thread = Thread::new_subagent(
4148 project.clone(),
4149 project_context,
4150 context_server_registry,
4151 Templates::new(),
4152 model.clone(),
4153 subagent_context,
4154 std::collections::BTreeMap::new(),
4155 cx,
4156 );
4157 thread.add_tool(EchoTool);
4158 thread.add_tool(DelayTool);
4159 thread.add_tool(WordListTool);
4160 thread
4161 });
4162
4163 subagent.read_with(cx, |thread, _| {
4164 assert!(thread.has_registered_tool("echo"));
4165 assert!(thread.has_registered_tool("delay"));
4166 assert!(thread.has_registered_tool("word_list"));
4167 });
4168
4169 let allowed: collections::HashSet<gpui::SharedString> =
4170 vec!["echo".into()].into_iter().collect();
4171
4172 subagent.update(cx, |thread, _cx| {
4173 thread.restrict_tools(&allowed);
4174 });
4175
4176 subagent.read_with(cx, |thread, _| {
4177 assert!(
4178 thread.has_registered_tool("echo"),
4179 "echo should still be available"
4180 );
4181 assert!(
4182 !thread.has_registered_tool("delay"),
4183 "delay should be removed"
4184 );
4185 assert!(
4186 !thread.has_registered_tool("word_list"),
4187 "word_list should be removed"
4188 );
4189 });
4190}
4191
// Checks that cancelling a parent thread also ends the running turn of a
// subagent registered on it via `register_running_subagent`.
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    // The parent only holds a weak handle to the subagent.
    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a turn the fake model never completes, so it stays in flight.
    subagent
        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    // NOTE(review): there is no `run_until_parked` between the cancel and this
    // check, so the test expects cancellation to reach the subagent synchronously
    // within `cancel` — confirm that this is intended.
    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4265
4266#[gpui::test]
4267async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4268 // This test verifies that the subagent tool properly handles user cancellation
4269 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4270 init_test(cx);
4271 always_allow_tools(cx);
4272
4273 cx.update(|cx| {
4274 cx.update_flags(true, vec!["subagents".to_string()]);
4275 });
4276
4277 let fs = FakeFs::new(cx.executor());
4278 fs.insert_tree(path!("/test"), json!({})).await;
4279 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4280 let project_context = cx.new(|_cx| ProjectContext::default());
4281 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4282 let context_server_registry =
4283 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4284 let model = Arc::new(FakeLanguageModel::default());
4285
4286 let parent = cx.new(|cx| {
4287 Thread::new(
4288 project.clone(),
4289 project_context.clone(),
4290 context_server_registry.clone(),
4291 Templates::new(),
4292 Some(model.clone()),
4293 cx,
4294 )
4295 });
4296
4297 #[allow(clippy::arc_with_non_send_sync)]
4298 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4299
4300 let (event_stream, _rx, mut cancellation_tx) =
4301 crate::ToolCallEventStream::test_with_cancellation();
4302
4303 // Start the subagent tool
4304 let task = cx.update(|cx| {
4305 tool.run(
4306 SubagentToolInput {
4307 label: "Long running task".to_string(),
4308 task_prompt: "Do a very long task that takes forever".to_string(),
4309 summary_prompt: "Summarize".to_string(),
4310 context_low_prompt: "Context low".to_string(),
4311 timeout_ms: None,
4312 allowed_tools: None,
4313 },
4314 event_stream.clone(),
4315 cx,
4316 )
4317 });
4318
4319 cx.run_until_parked();
4320
4321 // Signal cancellation via the event stream
4322 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4323
4324 // The task should complete promptly with a cancellation error
4325 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4326 let result = futures::select! {
4327 result = task.fuse() => result,
4328 _ = timeout.fuse() => {
4329 panic!("subagent tool did not respond to cancellation within timeout");
4330 }
4331 };
4332
4333 // Verify we got a cancellation error
4334 let err = result.unwrap_err();
4335 assert!(
4336 err.to_string().contains("cancelled by user"),
4337 "expected cancellation error, got: {}",
4338 err
4339 );
4340}
4341
4342#[gpui::test]
4343async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4344 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4345 let fake_model = model.as_fake();
4346
4347 cx.update(|cx| {
4348 cx.update_flags(true, vec!["subagents".to_string()]);
4349 });
4350
4351 let subagent_context = SubagentContext {
4352 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4353 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4354 depth: 1,
4355 summary_prompt: "Summarize".to_string(),
4356 context_low_prompt: "Context low".to_string(),
4357 };
4358
4359 let project = thread.read_with(cx, |t, _| t.project.clone());
4360 let project_context = cx.new(|_cx| ProjectContext::default());
4361 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4362 let context_server_registry =
4363 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4364
4365 let subagent = cx.new(|cx| {
4366 Thread::new_subagent(
4367 project.clone(),
4368 project_context,
4369 context_server_registry,
4370 Templates::new(),
4371 model.clone(),
4372 subagent_context,
4373 std::collections::BTreeMap::new(),
4374 cx,
4375 )
4376 });
4377
4378 subagent
4379 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4380 .unwrap();
4381 cx.run_until_parked();
4382
4383 subagent.read_with(cx, |thread, _| {
4384 assert!(!thread.is_turn_complete(), "turn should be in progress");
4385 });
4386
4387 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4388 provider: LanguageModelProviderName::from("Fake".to_string()),
4389 });
4390 fake_model.end_last_completion_stream();
4391 cx.run_until_parked();
4392
4393 subagent.read_with(cx, |thread, _| {
4394 assert!(
4395 thread.is_turn_complete(),
4396 "turn should be complete after non-retryable error"
4397 );
4398 });
4399}
4400
4401#[gpui::test]
4402async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4403 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4404 let fake_model = model.as_fake();
4405
4406 cx.update(|cx| {
4407 cx.update_flags(true, vec!["subagents".to_string()]);
4408 });
4409
4410 let subagent_context = SubagentContext {
4411 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4412 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4413 depth: 1,
4414 summary_prompt: "Summarize your work".to_string(),
4415 context_low_prompt: "Context low, stop and summarize".to_string(),
4416 };
4417
4418 let project = thread.read_with(cx, |t, _| t.project.clone());
4419 let project_context = cx.new(|_cx| ProjectContext::default());
4420 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4421 let context_server_registry =
4422 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4423
4424 let subagent = cx.new(|cx| {
4425 Thread::new_subagent(
4426 project.clone(),
4427 project_context.clone(),
4428 context_server_registry.clone(),
4429 Templates::new(),
4430 model.clone(),
4431 subagent_context.clone(),
4432 std::collections::BTreeMap::new(),
4433 cx,
4434 )
4435 });
4436
4437 subagent.update(cx, |thread, _| {
4438 thread.add_tool(EchoTool);
4439 });
4440
4441 subagent
4442 .update(cx, |thread, cx| {
4443 thread.submit_user_message("Do some work", cx)
4444 })
4445 .unwrap();
4446 cx.run_until_parked();
4447
4448 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4449 fake_model
4450 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4451 fake_model.end_last_completion_stream();
4452 cx.run_until_parked();
4453
4454 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4455 assert!(
4456 interrupt_result.is_ok(),
4457 "interrupt_for_summary should succeed"
4458 );
4459
4460 cx.run_until_parked();
4461
4462 let pending = fake_model.pending_completions();
4463 assert!(
4464 !pending.is_empty(),
4465 "should have pending completion for interrupted summary"
4466 );
4467
4468 let messages = &pending.last().unwrap().messages;
4469 let user_messages: Vec<_> = messages
4470 .iter()
4471 .filter(|m| m.role == language_model::Role::User)
4472 .collect();
4473
4474 let last_user = user_messages.last().unwrap();
4475 let content_str = last_user.content[0].to_str().unwrap();
4476 assert!(
4477 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4478 "context_low_prompt should be sent when interrupting: got {:?}",
4479 content_str
4480 );
4481}
4482
4483#[gpui::test]
4484async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4485 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4486 let fake_model = model.as_fake();
4487
4488 cx.update(|cx| {
4489 cx.update_flags(true, vec!["subagents".to_string()]);
4490 });
4491
4492 let subagent_context = SubagentContext {
4493 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4494 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4495 depth: 1,
4496 summary_prompt: "Summarize".to_string(),
4497 context_low_prompt: "Context low".to_string(),
4498 };
4499
4500 let project = thread.read_with(cx, |t, _| t.project.clone());
4501 let project_context = cx.new(|_cx| ProjectContext::default());
4502 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4503 let context_server_registry =
4504 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4505
4506 let subagent = cx.new(|cx| {
4507 Thread::new_subagent(
4508 project.clone(),
4509 project_context,
4510 context_server_registry,
4511 Templates::new(),
4512 model.clone(),
4513 subagent_context,
4514 std::collections::BTreeMap::new(),
4515 cx,
4516 )
4517 });
4518
4519 subagent
4520 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4521 .unwrap();
4522 cx.run_until_parked();
4523
4524 let max_tokens = model.max_token_count();
4525 let high_usage = language_model::TokenUsage {
4526 input_tokens: (max_tokens as f64 * 0.80) as u64,
4527 output_tokens: 0,
4528 cache_creation_input_tokens: 0,
4529 cache_read_input_tokens: 0,
4530 };
4531
4532 fake_model
4533 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4534 fake_model.send_last_completion_stream_text_chunk("Working...");
4535 fake_model
4536 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4537 fake_model.end_last_completion_stream();
4538 cx.run_until_parked();
4539
4540 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4541 assert!(usage.is_some(), "should have token usage after completion");
4542
4543 let usage = usage.unwrap();
4544 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4545 assert!(
4546 remaining_ratio <= 0.25,
4547 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4548 remaining_ratio * 100.0
4549 );
4550}
4551
4552#[gpui::test]
4553async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4554 init_test(cx);
4555
4556 cx.update(|cx| {
4557 cx.update_flags(true, vec!["subagents".to_string()]);
4558 });
4559
4560 let fs = FakeFs::new(cx.executor());
4561 fs.insert_tree(path!("/test"), json!({})).await;
4562 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4563 let project_context = cx.new(|_cx| ProjectContext::default());
4564 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4565 let context_server_registry =
4566 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4567 let model = Arc::new(FakeLanguageModel::default());
4568
4569 let parent = cx.new(|cx| {
4570 let mut thread = Thread::new(
4571 project.clone(),
4572 project_context.clone(),
4573 context_server_registry.clone(),
4574 Templates::new(),
4575 Some(model.clone()),
4576 cx,
4577 );
4578 thread.add_tool(EchoTool);
4579 thread
4580 });
4581
4582 #[allow(clippy::arc_with_non_send_sync)]
4583 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4584
4585 let allowed_tools = Some(vec!["nonexistent_tool".to_string()]);
4586 let result = cx.read(|cx| tool.validate_allowed_tools(&allowed_tools, cx));
4587
4588 assert!(result.is_err(), "should reject unknown tool");
4589 let err_msg = result.unwrap_err().to_string();
4590 assert!(
4591 err_msg.contains("nonexistent_tool"),
4592 "error should mention the invalid tool name: {}",
4593 err_msg
4594 );
4595 assert!(
4596 err_msg.contains("do not exist"),
4597 "error should explain the tool does not exist: {}",
4598 err_msg
4599 );
4600}
4601
4602#[gpui::test]
4603async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4604 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4605 let fake_model = model.as_fake();
4606
4607 cx.update(|cx| {
4608 cx.update_flags(true, vec!["subagents".to_string()]);
4609 });
4610
4611 let subagent_context = SubagentContext {
4612 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4613 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4614 depth: 1,
4615 summary_prompt: "Summarize".to_string(),
4616 context_low_prompt: "Context low".to_string(),
4617 };
4618
4619 let project = thread.read_with(cx, |t, _| t.project.clone());
4620 let project_context = cx.new(|_cx| ProjectContext::default());
4621 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4622 let context_server_registry =
4623 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4624
4625 let subagent = cx.new(|cx| {
4626 Thread::new_subagent(
4627 project.clone(),
4628 project_context,
4629 context_server_registry,
4630 Templates::new(),
4631 model.clone(),
4632 subagent_context,
4633 std::collections::BTreeMap::new(),
4634 cx,
4635 )
4636 });
4637
4638 subagent
4639 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4640 .unwrap();
4641 cx.run_until_parked();
4642
4643 fake_model
4644 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4645 fake_model.end_last_completion_stream();
4646 cx.run_until_parked();
4647
4648 subagent.read_with(cx, |thread, _| {
4649 assert!(
4650 thread.is_turn_complete(),
4651 "turn should complete even with empty response"
4652 );
4653 });
4654}
4655
4656#[gpui::test]
4657async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4658 init_test(cx);
4659
4660 cx.update(|cx| {
4661 cx.update_flags(true, vec!["subagents".to_string()]);
4662 });
4663
4664 let fs = FakeFs::new(cx.executor());
4665 fs.insert_tree(path!("/test"), json!({})).await;
4666 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4667 let project_context = cx.new(|_cx| ProjectContext::default());
4668 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4669 let context_server_registry =
4670 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4671 let model = Arc::new(FakeLanguageModel::default());
4672
4673 let depth_1_context = SubagentContext {
4674 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4675 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4676 depth: 1,
4677 summary_prompt: "Summarize".to_string(),
4678 context_low_prompt: "Context low".to_string(),
4679 };
4680
4681 let depth_1_subagent = cx.new(|cx| {
4682 Thread::new_subagent(
4683 project.clone(),
4684 project_context.clone(),
4685 context_server_registry.clone(),
4686 Templates::new(),
4687 model.clone(),
4688 depth_1_context,
4689 std::collections::BTreeMap::new(),
4690 cx,
4691 )
4692 });
4693
4694 depth_1_subagent.read_with(cx, |thread, _| {
4695 assert_eq!(thread.depth(), 1);
4696 assert!(thread.is_subagent());
4697 });
4698
4699 let depth_2_context = SubagentContext {
4700 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4701 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4702 depth: 2,
4703 summary_prompt: "Summarize depth 2".to_string(),
4704 context_low_prompt: "Context low depth 2".to_string(),
4705 };
4706
4707 let depth_2_subagent = cx.new(|cx| {
4708 Thread::new_subagent(
4709 project.clone(),
4710 project_context.clone(),
4711 context_server_registry.clone(),
4712 Templates::new(),
4713 model.clone(),
4714 depth_2_context,
4715 std::collections::BTreeMap::new(),
4716 cx,
4717 )
4718 });
4719
4720 depth_2_subagent.read_with(cx, |thread, _| {
4721 assert_eq!(thread.depth(), 2);
4722 assert!(thread.is_subagent());
4723 });
4724
4725 depth_2_subagent
4726 .update(cx, |thread, cx| {
4727 thread.submit_user_message("Nested task", cx)
4728 })
4729 .unwrap();
4730 cx.run_until_parked();
4731
4732 let pending = model.as_fake().pending_completions();
4733 assert!(
4734 !pending.is_empty(),
4735 "depth-2 subagent should be able to submit messages"
4736 );
4737}
4738
4739#[gpui::test]
4740async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4741 init_test(cx);
4742 always_allow_tools(cx);
4743
4744 cx.update(|cx| {
4745 cx.update_flags(true, vec!["subagents".to_string()]);
4746 });
4747
4748 let fs = FakeFs::new(cx.executor());
4749 fs.insert_tree(path!("/test"), json!({})).await;
4750 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4751 let project_context = cx.new(|_cx| ProjectContext::default());
4752 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4753 let context_server_registry =
4754 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4755 let model = Arc::new(FakeLanguageModel::default());
4756 let fake_model = model.as_fake();
4757
4758 let subagent_context = SubagentContext {
4759 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4760 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4761 depth: 1,
4762 summary_prompt: "Summarize what you did".to_string(),
4763 context_low_prompt: "Context low".to_string(),
4764 };
4765
4766 let subagent = cx.new(|cx| {
4767 let mut thread = Thread::new_subagent(
4768 project.clone(),
4769 project_context,
4770 context_server_registry,
4771 Templates::new(),
4772 model.clone(),
4773 subagent_context,
4774 std::collections::BTreeMap::new(),
4775 cx,
4776 );
4777 thread.add_tool(EchoTool);
4778 thread
4779 });
4780
4781 subagent.read_with(cx, |thread, _| {
4782 assert!(
4783 thread.has_registered_tool("echo"),
4784 "subagent should have echo tool"
4785 );
4786 });
4787
4788 subagent
4789 .update(cx, |thread, cx| {
4790 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4791 })
4792 .unwrap();
4793 cx.run_until_parked();
4794
4795 let tool_use = LanguageModelToolUse {
4796 id: "tool_call_1".into(),
4797 name: EchoTool::name().into(),
4798 raw_input: json!({"text": "hello world"}).to_string(),
4799 input: json!({"text": "hello world"}),
4800 is_input_complete: true,
4801 thought_signature: None,
4802 };
4803 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4804 fake_model.end_last_completion_stream();
4805 cx.run_until_parked();
4806
4807 let pending = fake_model.pending_completions();
4808 assert!(
4809 !pending.is_empty(),
4810 "should have pending completion after tool use"
4811 );
4812
4813 let last_completion = pending.last().unwrap();
4814 let has_tool_result = last_completion.messages.iter().any(|m| {
4815 m.content
4816 .iter()
4817 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4818 });
4819 assert!(
4820 has_tool_result,
4821 "tool result should be in the messages sent back to the model"
4822 );
4823}
4824
4825#[gpui::test]
4826async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4827 init_test(cx);
4828
4829 cx.update(|cx| {
4830 cx.update_flags(true, vec!["subagents".to_string()]);
4831 });
4832
4833 let fs = FakeFs::new(cx.executor());
4834 fs.insert_tree(path!("/test"), json!({})).await;
4835 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4836 let project_context = cx.new(|_cx| ProjectContext::default());
4837 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4838 let context_server_registry =
4839 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4840 let model = Arc::new(FakeLanguageModel::default());
4841
4842 let parent = cx.new(|cx| {
4843 Thread::new(
4844 project.clone(),
4845 project_context.clone(),
4846 context_server_registry.clone(),
4847 Templates::new(),
4848 Some(model.clone()),
4849 cx,
4850 )
4851 });
4852
4853 let mut subagents = Vec::new();
4854 for i in 0..MAX_PARALLEL_SUBAGENTS {
4855 let subagent_context = SubagentContext {
4856 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4857 tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4858 depth: 1,
4859 summary_prompt: "Summarize".to_string(),
4860 context_low_prompt: "Context low".to_string(),
4861 };
4862
4863 let subagent = cx.new(|cx| {
4864 Thread::new_subagent(
4865 project.clone(),
4866 project_context.clone(),
4867 context_server_registry.clone(),
4868 Templates::new(),
4869 model.clone(),
4870 subagent_context,
4871 std::collections::BTreeMap::new(),
4872 cx,
4873 )
4874 });
4875
4876 parent.update(cx, |thread, _cx| {
4877 thread.register_running_subagent(subagent.downgrade());
4878 });
4879 subagents.push(subagent);
4880 }
4881
4882 parent.read_with(cx, |thread, _| {
4883 assert_eq!(
4884 thread.running_subagent_count(),
4885 MAX_PARALLEL_SUBAGENTS,
4886 "should have MAX_PARALLEL_SUBAGENTS registered"
4887 );
4888 });
4889
4890 #[allow(clippy::arc_with_non_send_sync)]
4891 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4892
4893 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4894
4895 let result = cx.update(|cx| {
4896 tool.run(
4897 SubagentToolInput {
4898 label: "Test".to_string(),
4899 task_prompt: "Do something".to_string(),
4900 summary_prompt: "Summarize".to_string(),
4901 context_low_prompt: "Context low".to_string(),
4902 timeout_ms: None,
4903 allowed_tools: None,
4904 },
4905 event_stream,
4906 cx,
4907 )
4908 });
4909
4910 let err = result.await.unwrap_err();
4911 assert!(
4912 err.to_string().contains("Maximum parallel subagents"),
4913 "should reject when max parallel subagents reached: {}",
4914 err
4915 );
4916
4917 drop(subagents);
4918}
4919
4920#[gpui::test]
4921async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4922 init_test(cx);
4923 always_allow_tools(cx);
4924
4925 cx.update(|cx| {
4926 cx.update_flags(true, vec!["subagents".to_string()]);
4927 });
4928
4929 let fs = FakeFs::new(cx.executor());
4930 fs.insert_tree(path!("/test"), json!({})).await;
4931 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4932 let project_context = cx.new(|_cx| ProjectContext::default());
4933 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4934 let context_server_registry =
4935 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4936 let model = Arc::new(FakeLanguageModel::default());
4937 let fake_model = model.as_fake();
4938
4939 let parent = cx.new(|cx| {
4940 let mut thread = Thread::new(
4941 project.clone(),
4942 project_context.clone(),
4943 context_server_registry.clone(),
4944 Templates::new(),
4945 Some(model.clone()),
4946 cx,
4947 );
4948 thread.add_tool(EchoTool);
4949 thread
4950 });
4951
4952 #[allow(clippy::arc_with_non_send_sync)]
4953 let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0));
4954
4955 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4956
4957 let task = cx.update(|cx| {
4958 tool.run(
4959 SubagentToolInput {
4960 label: "Research task".to_string(),
4961 task_prompt: "Find all TODOs in the codebase".to_string(),
4962 summary_prompt: "Summarize what you found".to_string(),
4963 context_low_prompt: "Context low, wrap up".to_string(),
4964 timeout_ms: None,
4965 allowed_tools: None,
4966 },
4967 event_stream,
4968 cx,
4969 )
4970 });
4971
4972 cx.run_until_parked();
4973
4974 let pending = fake_model.pending_completions();
4975 assert!(
4976 !pending.is_empty(),
4977 "subagent should have started and sent a completion request"
4978 );
4979
4980 let first_completion = &pending[0];
4981 let has_task_prompt = first_completion.messages.iter().any(|m| {
4982 m.role == language_model::Role::User
4983 && m.content
4984 .iter()
4985 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
4986 });
4987 assert!(has_task_prompt, "task prompt should be sent to subagent");
4988
4989 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
4990 fake_model
4991 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4992 fake_model.end_last_completion_stream();
4993 cx.run_until_parked();
4994
4995 let pending = fake_model.pending_completions();
4996 assert!(
4997 !pending.is_empty(),
4998 "should have pending completion for summary request"
4999 );
5000
5001 let last_completion = pending.last().unwrap();
5002 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5003 m.role == language_model::Role::User
5004 && m.content.iter().any(|c| {
5005 c.to_str()
5006 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5007 .unwrap_or(false)
5008 })
5009 });
5010 assert!(
5011 has_summary_prompt,
5012 "summary prompt should be sent after task completion"
5013 );
5014
5015 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5016 fake_model
5017 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5018 fake_model.end_last_completion_stream();
5019 cx.run_until_parked();
5020
5021 let result = task.await;
5022 assert!(result.is_ok(), "subagent tool should complete successfully");
5023
5024 let summary = result.unwrap();
5025 assert!(
5026 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5027 "summary should contain subagent's response: {}",
5028 summary
5029 );
5030}
5031
5032#[gpui::test]
5033async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5034 init_test(cx);
5035
5036 let fs = FakeFs::new(cx.executor());
5037 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5038 .await;
5039 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5040
5041 cx.update(|cx| {
5042 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5043 settings.tool_permissions.tools.insert(
5044 "edit_file".into(),
5045 agent_settings::ToolRules {
5046 default_mode: settings::ToolPermissionMode::Allow,
5047 always_allow: vec![],
5048 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5049 always_confirm: vec![],
5050 invalid_patterns: vec![],
5051 },
5052 );
5053 agent_settings::AgentSettings::override_global(settings, cx);
5054 });
5055
5056 let context_server_registry =
5057 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5058 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5059 let templates = crate::Templates::new();
5060 let thread = cx.new(|cx| {
5061 crate::Thread::new(
5062 project.clone(),
5063 cx.new(|_cx| prompt_store::ProjectContext::default()),
5064 context_server_registry,
5065 templates.clone(),
5066 None,
5067 cx,
5068 )
5069 });
5070
5071 #[allow(clippy::arc_with_non_send_sync)]
5072 let tool = Arc::new(crate::EditFileTool::new(
5073 project.clone(),
5074 thread.downgrade(),
5075 language_registry,
5076 templates,
5077 ));
5078 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5079
5080 let task = cx.update(|cx| {
5081 tool.run(
5082 crate::EditFileToolInput {
5083 display_description: "Edit sensitive file".to_string(),
5084 path: "root/sensitive_config.txt".into(),
5085 mode: crate::EditFileMode::Edit,
5086 },
5087 event_stream,
5088 cx,
5089 )
5090 });
5091
5092 let result = task.await;
5093 assert!(result.is_err(), "expected edit to be blocked");
5094 assert!(
5095 result.unwrap_err().to_string().contains("blocked"),
5096 "error should mention the edit was blocked"
5097 );
5098}
5099
5100#[gpui::test]
5101async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5102 init_test(cx);
5103
5104 let fs = FakeFs::new(cx.executor());
5105 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5106 .await;
5107 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5108
5109 cx.update(|cx| {
5110 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5111 settings.tool_permissions.tools.insert(
5112 "delete_path".into(),
5113 agent_settings::ToolRules {
5114 default_mode: settings::ToolPermissionMode::Allow,
5115 always_allow: vec![],
5116 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5117 always_confirm: vec![],
5118 invalid_patterns: vec![],
5119 },
5120 );
5121 agent_settings::AgentSettings::override_global(settings, cx);
5122 });
5123
5124 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5125
5126 #[allow(clippy::arc_with_non_send_sync)]
5127 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5128 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5129
5130 let task = cx.update(|cx| {
5131 tool.run(
5132 crate::DeletePathToolInput {
5133 path: "root/important_data.txt".to_string(),
5134 },
5135 event_stream,
5136 cx,
5137 )
5138 });
5139
5140 let result = task.await;
5141 assert!(result.is_err(), "expected deletion to be blocked");
5142 assert!(
5143 result.unwrap_err().to_string().contains("blocked"),
5144 "error should mention the deletion was blocked"
5145 );
5146}
5147
5148#[gpui::test]
5149async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5150 init_test(cx);
5151
5152 let fs = FakeFs::new(cx.executor());
5153 fs.insert_tree(
5154 "/root",
5155 json!({
5156 "safe.txt": "content",
5157 "protected": {}
5158 }),
5159 )
5160 .await;
5161 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5162
5163 cx.update(|cx| {
5164 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5165 settings.tool_permissions.tools.insert(
5166 "move_path".into(),
5167 agent_settings::ToolRules {
5168 default_mode: settings::ToolPermissionMode::Allow,
5169 always_allow: vec![],
5170 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5171 always_confirm: vec![],
5172 invalid_patterns: vec![],
5173 },
5174 );
5175 agent_settings::AgentSettings::override_global(settings, cx);
5176 });
5177
5178 #[allow(clippy::arc_with_non_send_sync)]
5179 let tool = Arc::new(crate::MovePathTool::new(project));
5180 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5181
5182 let task = cx.update(|cx| {
5183 tool.run(
5184 crate::MovePathToolInput {
5185 source_path: "root/safe.txt".to_string(),
5186 destination_path: "root/protected/safe.txt".to_string(),
5187 },
5188 event_stream,
5189 cx,
5190 )
5191 });
5192
5193 let result = task.await;
5194 assert!(
5195 result.is_err(),
5196 "expected move to be blocked due to destination path"
5197 );
5198 assert!(
5199 result.unwrap_err().to_string().contains("blocked"),
5200 "error should mention the move was blocked"
5201 );
5202}
5203
5204#[gpui::test]
5205async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5206 init_test(cx);
5207
5208 let fs = FakeFs::new(cx.executor());
5209 fs.insert_tree(
5210 "/root",
5211 json!({
5212 "secret.txt": "secret content",
5213 "public": {}
5214 }),
5215 )
5216 .await;
5217 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5218
5219 cx.update(|cx| {
5220 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5221 settings.tool_permissions.tools.insert(
5222 "move_path".into(),
5223 agent_settings::ToolRules {
5224 default_mode: settings::ToolPermissionMode::Allow,
5225 always_allow: vec![],
5226 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5227 always_confirm: vec![],
5228 invalid_patterns: vec![],
5229 },
5230 );
5231 agent_settings::AgentSettings::override_global(settings, cx);
5232 });
5233
5234 #[allow(clippy::arc_with_non_send_sync)]
5235 let tool = Arc::new(crate::MovePathTool::new(project));
5236 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5237
5238 let task = cx.update(|cx| {
5239 tool.run(
5240 crate::MovePathToolInput {
5241 source_path: "root/secret.txt".to_string(),
5242 destination_path: "root/public/not_secret.txt".to_string(),
5243 },
5244 event_stream,
5245 cx,
5246 )
5247 });
5248
5249 let result = task.await;
5250 assert!(
5251 result.is_err(),
5252 "expected move to be blocked due to source path"
5253 );
5254 assert!(
5255 result.unwrap_err().to_string().contains("blocked"),
5256 "error should mention the move was blocked"
5257 );
5258}
5259
5260#[gpui::test]
5261async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5262 init_test(cx);
5263
5264 let fs = FakeFs::new(cx.executor());
5265 fs.insert_tree(
5266 "/root",
5267 json!({
5268 "confidential.txt": "confidential data",
5269 "dest": {}
5270 }),
5271 )
5272 .await;
5273 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5274
5275 cx.update(|cx| {
5276 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5277 settings.tool_permissions.tools.insert(
5278 "copy_path".into(),
5279 agent_settings::ToolRules {
5280 default_mode: settings::ToolPermissionMode::Allow,
5281 always_allow: vec![],
5282 always_deny: vec![
5283 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5284 ],
5285 always_confirm: vec![],
5286 invalid_patterns: vec![],
5287 },
5288 );
5289 agent_settings::AgentSettings::override_global(settings, cx);
5290 });
5291
5292 #[allow(clippy::arc_with_non_send_sync)]
5293 let tool = Arc::new(crate::CopyPathTool::new(project));
5294 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5295
5296 let task = cx.update(|cx| {
5297 tool.run(
5298 crate::CopyPathToolInput {
5299 source_path: "root/confidential.txt".to_string(),
5300 destination_path: "root/dest/copy.txt".to_string(),
5301 },
5302 event_stream,
5303 cx,
5304 )
5305 });
5306
5307 let result = task.await;
5308 assert!(result.is_err(), "expected copy to be blocked");
5309 assert!(
5310 result.unwrap_err().to_string().contains("blocked"),
5311 "error should mention the copy was blocked"
5312 );
5313}
5314
5315#[gpui::test]
5316async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5317 init_test(cx);
5318
5319 let fs = FakeFs::new(cx.executor());
5320 fs.insert_tree(
5321 "/root",
5322 json!({
5323 "normal.txt": "normal content",
5324 "readonly": {
5325 "config.txt": "readonly content"
5326 }
5327 }),
5328 )
5329 .await;
5330 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5331
5332 cx.update(|cx| {
5333 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5334 settings.tool_permissions.tools.insert(
5335 "save_file".into(),
5336 agent_settings::ToolRules {
5337 default_mode: settings::ToolPermissionMode::Allow,
5338 always_allow: vec![],
5339 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5340 always_confirm: vec![],
5341 invalid_patterns: vec![],
5342 },
5343 );
5344 agent_settings::AgentSettings::override_global(settings, cx);
5345 });
5346
5347 #[allow(clippy::arc_with_non_send_sync)]
5348 let tool = Arc::new(crate::SaveFileTool::new(project));
5349 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5350
5351 let task = cx.update(|cx| {
5352 tool.run(
5353 crate::SaveFileToolInput {
5354 paths: vec![
5355 std::path::PathBuf::from("root/normal.txt"),
5356 std::path::PathBuf::from("root/readonly/config.txt"),
5357 ],
5358 },
5359 event_stream,
5360 cx,
5361 )
5362 });
5363
5364 let result = task.await;
5365 assert!(
5366 result.is_err(),
5367 "expected save to be blocked due to denied path"
5368 );
5369 assert!(
5370 result.unwrap_err().to_string().contains("blocked"),
5371 "error should mention the save was blocked"
5372 );
5373}
5374
5375#[gpui::test]
5376async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5377 init_test(cx);
5378
5379 let fs = FakeFs::new(cx.executor());
5380 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5381 .await;
5382 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5383
5384 cx.update(|cx| {
5385 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5386 settings.always_allow_tool_actions = false;
5387 settings.tool_permissions.tools.insert(
5388 "save_file".into(),
5389 agent_settings::ToolRules {
5390 default_mode: settings::ToolPermissionMode::Allow,
5391 always_allow: vec![],
5392 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5393 always_confirm: vec![],
5394 invalid_patterns: vec![],
5395 },
5396 );
5397 agent_settings::AgentSettings::override_global(settings, cx);
5398 });
5399
5400 #[allow(clippy::arc_with_non_send_sync)]
5401 let tool = Arc::new(crate::SaveFileTool::new(project));
5402 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5403
5404 let task = cx.update(|cx| {
5405 tool.run(
5406 crate::SaveFileToolInput {
5407 paths: vec![std::path::PathBuf::from("root/config.secret")],
5408 },
5409 event_stream,
5410 cx,
5411 )
5412 });
5413
5414 let result = task.await;
5415 assert!(result.is_err(), "expected save to be blocked");
5416 assert!(
5417 result.unwrap_err().to_string().contains("blocked"),
5418 "error should mention the save was blocked"
5419 );
5420}
5421
5422#[gpui::test]
5423async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5424 init_test(cx);
5425
5426 cx.update(|cx| {
5427 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5428 settings.tool_permissions.tools.insert(
5429 "web_search".into(),
5430 agent_settings::ToolRules {
5431 default_mode: settings::ToolPermissionMode::Allow,
5432 always_allow: vec![],
5433 always_deny: vec![
5434 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5435 ],
5436 always_confirm: vec![],
5437 invalid_patterns: vec![],
5438 },
5439 );
5440 agent_settings::AgentSettings::override_global(settings, cx);
5441 });
5442
5443 #[allow(clippy::arc_with_non_send_sync)]
5444 let tool = Arc::new(crate::WebSearchTool);
5445 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5446
5447 let input: crate::WebSearchToolInput =
5448 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5449
5450 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5451
5452 let result = task.await;
5453 assert!(result.is_err(), "expected search to be blocked");
5454 assert!(
5455 result.unwrap_err().to_string().contains("blocked"),
5456 "error should mention the search was blocked"
5457 );
5458}
5459
5460#[gpui::test]
5461async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5462 init_test(cx);
5463
5464 let fs = FakeFs::new(cx.executor());
5465 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5466 .await;
5467 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5468
5469 cx.update(|cx| {
5470 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5471 settings.always_allow_tool_actions = false;
5472 settings.tool_permissions.tools.insert(
5473 "edit_file".into(),
5474 agent_settings::ToolRules {
5475 default_mode: settings::ToolPermissionMode::Confirm,
5476 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5477 always_deny: vec![],
5478 always_confirm: vec![],
5479 invalid_patterns: vec![],
5480 },
5481 );
5482 agent_settings::AgentSettings::override_global(settings, cx);
5483 });
5484
5485 let context_server_registry =
5486 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5487 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5488 let templates = crate::Templates::new();
5489 let thread = cx.new(|cx| {
5490 crate::Thread::new(
5491 project.clone(),
5492 cx.new(|_cx| prompt_store::ProjectContext::default()),
5493 context_server_registry,
5494 templates.clone(),
5495 None,
5496 cx,
5497 )
5498 });
5499
5500 #[allow(clippy::arc_with_non_send_sync)]
5501 let tool = Arc::new(crate::EditFileTool::new(
5502 project,
5503 thread.downgrade(),
5504 language_registry,
5505 templates,
5506 ));
5507 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5508
5509 let _task = cx.update(|cx| {
5510 tool.run(
5511 crate::EditFileToolInput {
5512 display_description: "Edit README".to_string(),
5513 path: "root/README.md".into(),
5514 mode: crate::EditFileMode::Edit,
5515 },
5516 event_stream,
5517 cx,
5518 )
5519 });
5520
5521 cx.run_until_parked();
5522
5523 let event = rx.try_next();
5524 assert!(
5525 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5526 "expected no authorization request for allowed .md file"
5527 );
5528}
5529
5530#[gpui::test]
5531async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5532 init_test(cx);
5533
5534 cx.update(|cx| {
5535 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5536 settings.tool_permissions.tools.insert(
5537 "fetch".into(),
5538 agent_settings::ToolRules {
5539 default_mode: settings::ToolPermissionMode::Allow,
5540 always_allow: vec![],
5541 always_deny: vec![
5542 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5543 ],
5544 always_confirm: vec![],
5545 invalid_patterns: vec![],
5546 },
5547 );
5548 agent_settings::AgentSettings::override_global(settings, cx);
5549 });
5550
5551 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5552
5553 #[allow(clippy::arc_with_non_send_sync)]
5554 let tool = Arc::new(crate::FetchTool::new(http_client));
5555 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5556
5557 let input: crate::FetchToolInput =
5558 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5559
5560 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5561
5562 let result = task.await;
5563 assert!(result.is_err(), "expected fetch to be blocked");
5564 assert!(
5565 result.unwrap_err().to_string().contains("blocked"),
5566 "error should mention the fetch was blocked"
5567 );
5568}
5569
5570#[gpui::test]
5571async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5572 init_test(cx);
5573
5574 cx.update(|cx| {
5575 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5576 settings.always_allow_tool_actions = false;
5577 settings.tool_permissions.tools.insert(
5578 "fetch".into(),
5579 agent_settings::ToolRules {
5580 default_mode: settings::ToolPermissionMode::Confirm,
5581 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5582 always_deny: vec![],
5583 always_confirm: vec![],
5584 invalid_patterns: vec![],
5585 },
5586 );
5587 agent_settings::AgentSettings::override_global(settings, cx);
5588 });
5589
5590 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5591
5592 #[allow(clippy::arc_with_non_send_sync)]
5593 let tool = Arc::new(crate::FetchTool::new(http_client));
5594 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5595
5596 let input: crate::FetchToolInput =
5597 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5598
5599 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5600
5601 cx.run_until_parked();
5602
5603 let event = rx.try_next();
5604 assert!(
5605 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5606 "expected no authorization request for allowed docs.rs URL"
5607 );
5608}
5609
5610#[gpui::test]
5611async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5612 init_test(cx);
5613 always_allow_tools(cx);
5614
5615 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5616 let fake_model = model.as_fake();
5617
5618 // Add a tool so we can simulate tool calls
5619 thread.update(cx, |thread, _cx| {
5620 thread.add_tool(EchoTool);
5621 });
5622
5623 // Start a turn by sending a message
5624 let mut events = thread
5625 .update(cx, |thread, cx| {
5626 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5627 })
5628 .unwrap();
5629 cx.run_until_parked();
5630
5631 // Simulate the model making a tool call
5632 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5633 LanguageModelToolUse {
5634 id: "tool_1".into(),
5635 name: "echo".into(),
5636 raw_input: r#"{"text": "hello"}"#.into(),
5637 input: json!({"text": "hello"}),
5638 is_input_complete: true,
5639 thought_signature: None,
5640 },
5641 ));
5642 fake_model
5643 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5644
5645 // Signal that a message is queued before ending the stream
5646 thread.update(cx, |thread, _cx| {
5647 thread.set_has_queued_message(true);
5648 });
5649
5650 // Now end the stream - tool will run, and the boundary check should see the queue
5651 fake_model.end_last_completion_stream();
5652
5653 // Collect all events until the turn stops
5654 let all_events = collect_events_until_stop(&mut events, cx).await;
5655
5656 // Verify we received the tool call event
5657 let tool_call_ids: Vec<_> = all_events
5658 .iter()
5659 .filter_map(|e| match e {
5660 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5661 _ => None,
5662 })
5663 .collect();
5664 assert_eq!(
5665 tool_call_ids,
5666 vec!["tool_1"],
5667 "Should have received a tool call event for our echo tool"
5668 );
5669
5670 // The turn should have stopped with EndTurn
5671 let stop_reasons = stop_events(all_events);
5672 assert_eq!(
5673 stop_reasons,
5674 vec![acp::StopReason::EndTurn],
5675 "Turn should have ended after tool completion due to queued message"
5676 );
5677
5678 // Verify the queued message flag is still set
5679 thread.update(cx, |thread, _cx| {
5680 assert!(
5681 thread.has_queued_message(),
5682 "Should still have queued message flag set"
5683 );
5684 });
5685
5686 // Thread should be idle now
5687 thread.update(cx, |thread, _cx| {
5688 assert!(
5689 thread.is_turn_complete(),
5690 "Thread should not be running after turn ends"
5691 );
5692 });
5693}