1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// Test double for a terminal spawned by the terminal tool.
///
/// Tracks whether `kill()` was invoked and lets tests control when the
/// simulated process "exits" via a stored oneshot sender.
struct FakeTerminalHandle {
    // Set by `kill()`, read by `was_killed()`.
    killed: Arc<AtomicBool>,
    // Simulates the user manually stopping the terminal; read via
    // `was_stopped_by_user()`.
    stopped_by_user: Arc<AtomicBool>,
    // Firing this sender resolves `wait_for_exit`. Wrapped in
    // `RefCell<Option<..>>` because the oneshot sender is consumed on use
    // while the handle itself is only ever borrowed.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Shared task that resolves once the terminal "exits".
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from `current_output()`.
    output: acp::TerminalOutputResponse,
    // Fixed terminal id reported to callers.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
impl crate::TerminalHandle for FakeTerminalHandle {
    /// Returns the fixed fake terminal id.
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    /// Returns the canned output configured at construction time.
    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    /// Returns a clone of the shared exit future; resolution timing depends
    /// on which constructor built the handle.
    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    /// Records the kill and resolves the exit future so waiters unblock.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        self.signal_exit();
        Ok(())
    }

    /// Reports the flag previously set via `set_stopped_by_user`.
    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
        Ok(self.stopped_by_user.load(Ordering::SeqCst))
    }
}
159
/// Test double for a subagent: replays a preconfigured shared task as the
/// reply to any message.
struct FakeSubagentHandle {
    // Session id reported from `id()`.
    session_id: acp::SessionId,
    // Shared task whose output is returned from `send()`, regardless of input.
    send_task: Shared<Task<String>>,
}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn num_entries(&self, _cx: &App) -> usize {
171 unimplemented!()
172 }
173
174 fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
175 let task = self.send_task.clone();
176 cx.background_spawn(async move { Ok(task.await) })
177 }
178}
179
/// Thread environment whose terminal/subagent handles are fixed, preconfigured
/// fakes; panics if a handle is requested but was never installed.
#[derive(Default)]
struct FakeThreadEnvironment {
    // Handle returned from every `create_terminal` call, if set.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    // Handle returned from every `create_subagent` call, if set.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
185
186impl FakeThreadEnvironment {
187 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
188 Self {
189 terminal_handle: Some(terminal_handle.into()),
190 ..self
191 }
192 }
193}
194
195impl crate::ThreadEnvironment for FakeThreadEnvironment {
196 fn create_terminal(
197 &self,
198 _command: String,
199 _cwd: Option<std::path::PathBuf>,
200 _output_byte_limit: Option<u64>,
201 _cx: &mut AsyncApp,
202 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
203 let handle = self
204 .terminal_handle
205 .clone()
206 .expect("Terminal handle not available on FakeThreadEnvironment");
207 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
208 }
209
210 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
211 Ok(self
212 .subagent_handle
213 .clone()
214 .expect("Subagent handle not available on FakeThreadEnvironment")
215 as Rc<dyn SubagentHandle>)
216 }
217}
218
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    // Every handle created so far, in creation order; `RefCell` because
    // `create_terminal` takes `&self` but must record the new handle.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
223
224impl MultiTerminalEnvironment {
225 fn new() -> Self {
226 Self {
227 handles: std::cell::RefCell::new(Vec::new()),
228 }
229 }
230
231 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
232 self.handles.borrow().clone()
233 }
234}
235
impl crate::ThreadEnvironment for MultiTerminalEnvironment {
    /// Spawns a fresh never-exiting fake terminal on every call and records
    /// it so tests can inspect all concurrently created handles.
    fn create_terminal(
        &self,
        _command: String,
        _cwd: Option<std::path::PathBuf>,
        _output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
        // NOTE(review): this relies on `AsyncApp::update` returning the
        // closure's value directly here — confirm against the gpui version
        // in use.
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        self.handles.borrow_mut().push(handle.clone());
        Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
    }

    /// Subagents are not exercised by the multi-terminal tests.
    fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
        unimplemented!()
    }
}
253
254fn always_allow_tools(cx: &mut TestAppContext) {
255 cx.update(|cx| {
256 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
257 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
258 agent_settings::AgentSettings::override_global(settings, cx);
259 });
260}
261
/// Plain prompt/response round trip: streamed assistant text ends up as the
/// last message and the turn stops with `EndTurn`.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Let the send reach the fake model before driving its stream.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Draining the event stream completes the turn.
    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message().unwrap().role(),
            Role::Assistant
        );
        assert_eq!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .to_markdown(),
            "Hello\n"
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
294
/// A terminal tool run with a tiny `timeout_ms` must kill its terminal handle
/// once the timeout elapses and still surface the partial output produced so
/// far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // The fake terminal never exits on its own, so only the timeout path can
    // complete the tool run.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                // Deliberately tiny so the timeout fires almost immediately.
                timeout_ms: Some(5),
            }),
            event_stream,
            cx,
        )
    });

    // The tool should first report the terminal as tool-call content.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked + 1ms ticks) until it
    // finishes, bounded by a wall-clock deadline so a regression can't hang
    // the test.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
360
/// With `timeout_ms: None`, a long-running terminal must NOT be killed — the
/// tool should keep waiting on it indefinitely.
// NOTE(review): this test is currently #[ignore]d.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits; without a timeout the tool just waits.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Keep the task alive (but never awaited to completion) for the duration
    // of the test.
    let _task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            }),
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give any (erroneous) kill path a window to fire before asserting.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
410
/// Streamed `Thinking` events should be rendered as a `<think>` section ahead
/// of the regular assistant text in the message's markdown.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream one thinking chunk, then the visible answer, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message().unwrap().role(),
            Role::Assistant
        );
        assert_eq!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
459
/// The first request message must be a system prompt that reflects the
/// project context (shell name) and, because a tool is registered, includes
/// the diagnostics-fixing section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable shell name so we can assert it was templated in.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
504
/// With no tools registered, the system prompt must omit the tool-use and
/// diagnostics sections entirely.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
540
/// Verifies the prompt-caching policy: exactly one message per request — the
/// most recent user message (or tool result) — carries `cache: true`; all
/// earlier messages are sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request triggered by the tool result: only that final
    // tool-result message should be cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
686
/// End-to-end (real model) test of tool calling, covering both a tool that
/// finishes before streaming stops and one that finishes after. Only runs
/// with the "e2e" feature.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model's final answer should echo DelayTool's "Ding" output.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
746
/// End-to-end (real model) test that tool-call input streams incrementally:
/// we should observe at least one partially-streamed tool use before the
/// input is complete. Only runs with the "e2e" feature.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_received_or_pending_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the later keys ("g") may not have
                        // streamed yet — that's the partial state we want.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
798
/// Exercises the tool-permission flow: allowing once, denying, choosing
/// "always allow" for a tool, and verifying that subsequent calls to an
/// always-allowed tool skip authorization entirely.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two simultaneous tool uses => two separate authorization requests.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries one successful and one errored tool result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
937
/// A model-invented tool name should surface as a pending tool call that is
/// then marked failed, rather than panicking or being silently dropped.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // No tool named "nonexistent_tool" is registered on the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
967
968async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
969 let event = events
970 .next()
971 .await
972 .expect("no tool call authorization event received")
973 .unwrap();
974 match event {
975 ThreadEvent::ToolCall(tool_call) => tool_call,
976 event => {
977 panic!("Unexpected event {event:?}");
978 }
979 }
980}
981
982async fn expect_tool_call_update_fields(
983 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
984) -> acp::ToolCallUpdate {
985 let event = events
986 .next()
987 .await
988 .expect("no tool call authorization event received")
989 .unwrap();
990 match event {
991 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
992 event => {
993 panic!("Unexpected event {event:?}");
994 }
995 }
996}
997
998async fn next_tool_call_authorization(
999 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1000) -> ToolCallAuthorization {
1001 loop {
1002 let event = events
1003 .next()
1004 .await
1005 .expect("no tool call authorization event received")
1006 .unwrap();
1007 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1008 let permission_kinds = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1011 .map(|option| option.kind);
1012 let allow_once = tool_call_authorization
1013 .options
1014 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1015 .map(|option| option.kind);
1016
1017 assert_eq!(
1018 permission_kinds,
1019 Some(acp::PermissionOptionKind::AllowAlways)
1020 );
1021 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1022 return tool_call_authorization;
1023 }
1024 }
1025}
1026
/// A multi-word `cargo` invocation should yield a dropdown with a generic
/// "always for terminal" choice, a `cargo build`-specific choice, and a
/// one-time choice.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let options = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1048
/// When the second token is a flag (`ls -la`), the per-command choice should
/// key off the first token only (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let options = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()])
        .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1068
/// A single-word command still gets its own per-command "always" option.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let options = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()])
        .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1088
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing a file under a directory should offer a directory-scoped
    // "always" option (`src/`) alongside the tool-wide one.
    let built = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()])
        .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let has_label = |label: &str| {
        choices
            .iter()
            .any(|choice| choice.allow.name.as_ref() == label)
    };
    assert!(has_label("Always for edit file"));
    assert!(has_label("Always for `src/`"));
}
1106
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL should offer a domain-scoped "always" option
    // (`docs.rs`) alongside the tool-wide one.
    let built = ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()])
        .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let has_label = |label: &str| {
        choices
            .iter()
            .any(|choice| choice.allow.name.as_ref() == label)
    };
    assert!(has_label("Always for fetch"));
    assert!(has_label("Always for `docs.rs`"));
}
1124
#[test]
fn test_permission_options_without_pattern() {
    // `./deploy.sh` does not produce a recognizable command pattern, so only
    // the tool-wide and one-time choices should be offered — no `commands`
    // entry at all.
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    )
    .build_permission_options();

    let choices = match built {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 2);
    let mut saw_tool_wide = false;
    let mut saw_one_time = false;
    for choice in &choices {
        let label = choice.allow.name.as_ref();
        assert!(
            !label.contains("commands"),
            "unexpected command-pattern option: {label}"
        );
        match label {
            "Always for terminal" => saw_tool_wide = true,
            "Only this time" => saw_one_time = true,
            _ => {}
        }
    }
    assert!(saw_tool_wide);
    assert!(saw_one_time);
}
1146
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization must not offer any "always" choices:
    // only a flat allow-once / reject-once pair.
    let built = ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    )
    .build_permission_options();

    let PermissionOptions::Flat(options) = built else {
        panic!("Expected flat permission options for symlink target authorization");
    };

    assert_eq!(options.len(), 2);
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1167
#[test]
fn test_permission_option_ids_for_terminal() {
    // Verify the stable identifier scheme for terminal permission options:
    // tool-wide, one-time, and pattern-scoped ids on both the allow and
    // deny sides.
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    // Gather the ids from both sides in a single pass.
    let mut allow_ids = Vec::new();
    let mut deny_ids = Vec::new();
    for choice in &choices {
        allow_ids.push(choice.allow.option_id.0.to_string());
        deny_ids.push(choice.deny.option_id.0.to_string());
    }

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1207
1208#[gpui::test]
1209#[cfg_attr(not(feature = "e2e"), ignore)]
1210async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1211 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1212
1213 // Test concurrent tool calls with different delay times
1214 let events = thread
1215 .update(cx, |thread, cx| {
1216 thread.add_tool(DelayTool);
1217 thread.send(
1218 UserMessageId::new(),
1219 [
1220 "Call the delay tool twice in the same message.",
1221 "Once with 100ms. Once with 300ms.",
1222 "When both timers are complete, describe the outputs.",
1223 ],
1224 cx,
1225 )
1226 })
1227 .unwrap()
1228 .collect()
1229 .await;
1230
1231 let stop_reasons = stop_events(events);
1232 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1233
1234 thread.update(cx, |thread, _cx| {
1235 let last_message = thread.last_received_or_pending_message().unwrap();
1236 let agent_message = last_message.as_agent_message().unwrap();
1237 let text = agent_message
1238 .content
1239 .iter()
1240 .filter_map(|content| {
1241 if let AgentMessageContent::Text(text) = content {
1242 Some(text.as_str())
1243 } else {
1244 None
1245 }
1246 })
1247 .collect::<String>();
1248
1249 assert!(text.contains("Ding"));
1250 });
1251}
1252
1253#[gpui::test]
1254async fn test_profiles(cx: &mut TestAppContext) {
1255 let ThreadTest {
1256 model, thread, fs, ..
1257 } = setup(cx, TestModel::Fake).await;
1258 let fake_model = model.as_fake();
1259
1260 thread.update(cx, |thread, _cx| {
1261 thread.add_tool(DelayTool);
1262 thread.add_tool(EchoTool);
1263 thread.add_tool(InfiniteTool);
1264 });
1265
1266 // Override profiles and wait for settings to be loaded.
1267 fs.insert_file(
1268 paths::settings_file(),
1269 json!({
1270 "agent": {
1271 "profiles": {
1272 "test-1": {
1273 "name": "Test Profile 1",
1274 "tools": {
1275 EchoTool::NAME: true,
1276 DelayTool::NAME: true,
1277 }
1278 },
1279 "test-2": {
1280 "name": "Test Profile 2",
1281 "tools": {
1282 InfiniteTool::NAME: true,
1283 }
1284 }
1285 }
1286 }
1287 })
1288 .to_string()
1289 .into_bytes(),
1290 )
1291 .await;
1292 cx.run_until_parked();
1293
1294 // Test that test-1 profile (default) has echo and delay tools
1295 thread
1296 .update(cx, |thread, cx| {
1297 thread.set_profile(AgentProfileId("test-1".into()), cx);
1298 thread.send(UserMessageId::new(), ["test"], cx)
1299 })
1300 .unwrap();
1301 cx.run_until_parked();
1302
1303 let mut pending_completions = fake_model.pending_completions();
1304 assert_eq!(pending_completions.len(), 1);
1305 let completion = pending_completions.pop().unwrap();
1306 let tool_names: Vec<String> = completion
1307 .tools
1308 .iter()
1309 .map(|tool| tool.name.clone())
1310 .collect();
1311 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1312 fake_model.end_last_completion_stream();
1313
1314 // Switch to test-2 profile, and verify that it has only the infinite tool.
1315 thread
1316 .update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test-2".into()), cx);
1318 thread.send(UserMessageId::new(), ["test2"], cx)
1319 })
1320 .unwrap();
1321 cx.run_until_parked();
1322 let mut pending_completions = fake_model.pending_completions();
1323 assert_eq!(pending_completions.len(), 1);
1324 let completion = pending_completions.pop().unwrap();
1325 let tool_names: Vec<String> = completion
1326 .tools
1327 .iter()
1328 .map(|tool| tool.name.clone())
1329 .collect();
1330 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1331}
1332
1333#[gpui::test]
1334async fn test_mcp_tools(cx: &mut TestAppContext) {
1335 let ThreadTest {
1336 model,
1337 thread,
1338 context_server_store,
1339 fs,
1340 ..
1341 } = setup(cx, TestModel::Fake).await;
1342 let fake_model = model.as_fake();
1343
1344 // Override profiles and wait for settings to be loaded.
1345 fs.insert_file(
1346 paths::settings_file(),
1347 json!({
1348 "agent": {
1349 "tool_permissions": { "default": "allow" },
1350 "profiles": {
1351 "test": {
1352 "name": "Test Profile",
1353 "enable_all_context_servers": true,
1354 "tools": {
1355 EchoTool::NAME: true,
1356 }
1357 },
1358 }
1359 }
1360 })
1361 .to_string()
1362 .into_bytes(),
1363 )
1364 .await;
1365 cx.run_until_parked();
1366 thread.update(cx, |thread, cx| {
1367 thread.set_profile(AgentProfileId("test".into()), cx)
1368 });
1369
1370 let mut mcp_tool_calls = setup_context_server(
1371 "test_server",
1372 vec![context_server::types::Tool {
1373 name: "echo".into(),
1374 description: None,
1375 input_schema: serde_json::to_value(EchoTool::input_schema(
1376 LanguageModelToolSchemaFormat::JsonSchema,
1377 ))
1378 .unwrap(),
1379 output_schema: None,
1380 annotations: None,
1381 }],
1382 &context_server_store,
1383 cx,
1384 );
1385
1386 let events = thread.update(cx, |thread, cx| {
1387 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1388 });
1389 cx.run_until_parked();
1390
1391 // Simulate the model calling the MCP tool.
1392 let completion = fake_model.pending_completions().pop().unwrap();
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1395 LanguageModelToolUse {
1396 id: "tool_1".into(),
1397 name: "echo".into(),
1398 raw_input: json!({"text": "test"}).to_string(),
1399 input: json!({"text": "test"}),
1400 is_input_complete: true,
1401 thought_signature: None,
1402 },
1403 ));
1404 fake_model.end_last_completion_stream();
1405 cx.run_until_parked();
1406
1407 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1408 assert_eq!(tool_call_params.name, "echo");
1409 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1410 tool_call_response
1411 .send(context_server::types::CallToolResponse {
1412 content: vec![context_server::types::ToolResponseContent::Text {
1413 text: "test".into(),
1414 }],
1415 is_error: None,
1416 meta: None,
1417 structured_content: None,
1418 })
1419 .unwrap();
1420 cx.run_until_parked();
1421
1422 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1423 fake_model.send_last_completion_stream_text_chunk("Done!");
1424 fake_model.end_last_completion_stream();
1425 events.collect::<Vec<_>>().await;
1426
1427 // Send again after adding the echo tool, ensuring the name collision is resolved.
1428 let events = thread.update(cx, |thread, cx| {
1429 thread.add_tool(EchoTool);
1430 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1431 });
1432 cx.run_until_parked();
1433 let completion = fake_model.pending_completions().pop().unwrap();
1434 assert_eq!(
1435 tool_names_for_completion(&completion),
1436 vec!["echo", "test_server_echo"]
1437 );
1438 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1439 LanguageModelToolUse {
1440 id: "tool_2".into(),
1441 name: "test_server_echo".into(),
1442 raw_input: json!({"text": "mcp"}).to_string(),
1443 input: json!({"text": "mcp"}),
1444 is_input_complete: true,
1445 thought_signature: None,
1446 },
1447 ));
1448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1449 LanguageModelToolUse {
1450 id: "tool_3".into(),
1451 name: "echo".into(),
1452 raw_input: json!({"text": "native"}).to_string(),
1453 input: json!({"text": "native"}),
1454 is_input_complete: true,
1455 thought_signature: None,
1456 },
1457 ));
1458 fake_model.end_last_completion_stream();
1459 cx.run_until_parked();
1460
1461 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1462 assert_eq!(tool_call_params.name, "echo");
1463 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1464 tool_call_response
1465 .send(context_server::types::CallToolResponse {
1466 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1467 is_error: None,
1468 meta: None,
1469 structured_content: None,
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473
1474 // Ensure the tool results were inserted with the correct names.
1475 let completion = fake_model.pending_completions().pop().unwrap();
1476 assert_eq!(
1477 completion.messages.last().unwrap().content,
1478 vec![
1479 MessageContent::ToolResult(LanguageModelToolResult {
1480 tool_use_id: "tool_3".into(),
1481 tool_name: "echo".into(),
1482 is_error: false,
1483 content: "native".into(),
1484 output: Some("native".into()),
1485 },),
1486 MessageContent::ToolResult(LanguageModelToolResult {
1487 tool_use_id: "tool_2".into(),
1488 tool_name: "test_server_echo".into(),
1489 is_error: false,
1490 content: "mcp".into(),
1491 output: Some("mcp".into()),
1492 },),
1493 ]
1494 );
1495 fake_model.end_last_completion_stream();
1496 events.collect::<Vec<_>>().await;
1497}
1498
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    // Regression test: a saved thread must still surface MCP tool results on
    // replay even when the MCP server that produced them is no longer
    // connected (the situation right after an app restart).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    context_server_store.update(cx, |store, cx| {
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    // Scan the replay stream for the original tool call and for an update
    // that carries the saved raw_output.
    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1680
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP tool names are disambiguated and truncated when
    // multiple servers expose conflicting or overlong tool names. The
    // expected list at the bottom shows the observed scheme: conflicting
    // names get a server-name prefix (snake_cased), and names that would
    // exceed MAX_TOOL_NAME_LENGTH are shortened — presumably by abbreviating
    // the server prefix (e.g. "y_"/"z_") — TODO confirm against the
    // disambiguation implementation.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit name, also present on server "zzz" to force
            // disambiguation of a name with no room for a full prefix.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Exceeds the limit outright; expected to be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1864
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end (real Sonnet 4): cancelling mid-turn while a tool is still
    // running must close the event stream with a Cancelled stop reason, and
    // the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls are expected in prompt order: Echo first, then Infinite.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Watch for the Echo tool's call transitioning to Completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop consuming once Echo has completed and Infinite has started.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1950
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // Cancelling a running terminal tool must kill the terminal AND preserve
    // the partial output in the tool result (rather than replacing it with a
    // bare "canceled" message).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A fake terminal whose command never exits, so cancellation is the only
    // way the tool can finish.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the terminal tool use recorded in the agent message.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2047
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is an observable flag the tool sets when it detects
    // cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported.
    // NOTE(review): the retry budget scales with CPU count, which looks
    // arbitrary — presumably just "enough iterations"; confirm intent.
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // Drain any events that are already available without blocking.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    // Breaks the inner drain loop; the outer loop exits via
                    // the `tool_started` check below.
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2142
2143/// Helper to verify thread can recover after cancellation by sending a simple message.
2144async fn verify_thread_recovery(
2145 thread: &Entity<Thread>,
2146 fake_model: &FakeLanguageModel,
2147 cx: &mut TestAppContext,
2148) {
2149 let events = thread
2150 .update(cx, |thread, cx| {
2151 thread.send(
2152 UserMessageId::new(),
2153 ["Testing: reply with 'Hello' then stop."],
2154 cx,
2155 )
2156 })
2157 .unwrap();
2158 cx.run_until_parked();
2159 fake_model.send_last_completion_stream_text_chunk("Hello");
2160 fake_model
2161 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2162 fake_model.end_last_completion_stream();
2163
2164 let events = events.collect::<Vec<_>>().await;
2165 thread.update(cx, |thread, _cx| {
2166 let message = thread.last_received_or_pending_message().unwrap();
2167 let agent_message = message.as_agent_message().unwrap();
2168 assert_eq!(
2169 agent_message.content,
2170 vec![AgentMessageContent::Text("Hello".to_string())]
2171 );
2172 });
2173 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2174}
2175
2176/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2177async fn wait_for_terminal_tool_started(
2178 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2179 cx: &mut TestAppContext,
2180) {
2181 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2182 for _ in 0..deadline {
2183 cx.run_until_parked();
2184
2185 while let Some(Some(event)) = events.next().now_or_never() {
2186 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2187 update,
2188 ))) = &event
2189 {
2190 if update.fields.content.as_ref().is_some_and(|content| {
2191 content
2192 .iter()
2193 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2194 }) {
2195 return;
2196 }
2197 }
2198 }
2199
2200 cx.background_executor
2201 .timer(Duration::from_millis(10))
2202 .await;
2203 }
2204 panic!("terminal tool did not start within the expected time");
2205}
2206
2207/// Collects events until a Stop event is received, driving the executor to completion.
2208async fn collect_events_until_stop(
2209 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2210 cx: &mut TestAppContext,
2211) -> Vec<Result<ThreadEvent>> {
2212 let mut collected = Vec::new();
2213 let deadline = cx.executor().num_cpus() * 200;
2214
2215 for _ in 0..deadline {
2216 cx.executor().advance_clock(Duration::from_millis(10));
2217 cx.run_until_parked();
2218
2219 while let Some(Some(event)) = events.next().now_or_never() {
2220 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2221 collected.push(event);
2222 if is_stop {
2223 return collected;
2224 }
2225 }
2226 }
2227 panic!(
2228 "did not receive Stop event within the expected time; collected {} events",
2229 collected.len()
2230 );
2231}
2232
// Truncating a thread while a terminal tool is still running must kill the
// terminal, leave the thread empty, and keep it usable for new messages.
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A terminal that never exits on its own, so it is guaranteed to still be
    // running when the truncation happens.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2299
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment that records every terminal handle it creates, so we can
    // inspect all of them via `handles()` after cancelling.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates
    let mut terminals_started = 0;
    // Polling budget scales with available parallelism.
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // Drain all currently-ready events; the `break` below only exits this
        // drain loop — the outer check afterwards ends the polling.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2408
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A terminal that never exits on its own, so the simulated "stop" below is
    // the only thing that can end it.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the terminal tool use recorded in the agent's message...
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // ...and the result stored for that same tool-use id.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2503
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A terminal that never exits on its own, so only the timeout can end it.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (timeout_ms is 100, so 200ms is plenty)
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the terminal tool use recorded in the agent's message...
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // ...and the result stored for that same tool-use id.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2602
2603#[gpui::test]
2604async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2605 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2606 let fake_model = model.as_fake();
2607
2608 let events_1 = thread
2609 .update(cx, |thread, cx| {
2610 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2611 })
2612 .unwrap();
2613 cx.run_until_parked();
2614 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2615 cx.run_until_parked();
2616
2617 let events_2 = thread
2618 .update(cx, |thread, cx| {
2619 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2620 })
2621 .unwrap();
2622 cx.run_until_parked();
2623 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2624 fake_model
2625 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2626 fake_model.end_last_completion_stream();
2627
2628 let events_1 = events_1.collect::<Vec<_>>().await;
2629 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2630 let events_2 = events_2.collect::<Vec<_>>().await;
2631 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2632}
2633
#[gpui::test]
async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
    // Regression test: when a completion fails with a retryable error (e.g. upstream 500),
    // the retry loop waits on a timer. If the user switches models and sends a new message
    // during that delay, the old turn should exit immediately instead of retrying with the
    // stale model.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let model_a = model.as_fake();

    // Start a turn with model_a.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // NOTE(review): `completion_count` appears to track pending completion
    // streams — it is 1 here while model_a is streaming and drops back to 0
    // below once the stream errors out. TODO confirm against FakeLanguageModel.
    assert_eq!(model_a.completion_count(), 1);

    // Model returns a retryable upstream 500. The turn enters the retry delay.
    model_a.send_last_completion_stream_error(
        LanguageModelCompletionError::UpstreamProviderError {
            message: "Internal server error".to_string(),
            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
            retry_after: None,
        },
    );
    model_a.end_last_completion_stream();
    cx.run_until_parked();

    // The old completion was consumed; model_a has no pending requests yet because the
    // retry timer hasn't fired.
    assert_eq!(model_a.completion_count(), 0);

    // Switch to model_b and send a new message. This cancels the old turn.
    let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake", "model-b", "Model B", false,
    ));
    thread.update(cx, |thread, cx| {
        thread.set_model(model_b.clone(), cx);
    });
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Continue"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // model_b should have received its completion request.
    assert_eq!(model_b.as_fake().completion_count(), 1);

    // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
    cx.executor().advance_clock(Duration::from_secs(10));
    cx.run_until_parked();

    // model_a must NOT have received another completion request — the cancelled turn
    // should have exited during the retry delay rather than retrying with the old model.
    assert_eq!(
        model_a.completion_count(),
        0,
        "old model should not receive a retry request after cancellation"
    );

    // Complete model_b's turn.
    model_b
        .as_fake()
        .send_last_completion_stream_text_chunk("Done!");
    model_b
        .as_fake()
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    model_b.as_fake().end_last_completion_stream();

    // The cancelled first turn reports Cancelled; the new turn ends normally.
    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);

    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
2711
2712#[gpui::test]
2713async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2714 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2715 let fake_model = model.as_fake();
2716
2717 let events_1 = thread
2718 .update(cx, |thread, cx| {
2719 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2720 })
2721 .unwrap();
2722 cx.run_until_parked();
2723 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2724 fake_model
2725 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2726 fake_model.end_last_completion_stream();
2727 let events_1 = events_1.collect::<Vec<_>>().await;
2728
2729 let events_2 = thread
2730 .update(cx, |thread, cx| {
2731 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2732 })
2733 .unwrap();
2734 cx.run_until_parked();
2735 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2736 fake_model
2737 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2738 fake_model.end_last_completion_stream();
2739 let events_2 = events_2.collect::<Vec<_>>().await;
2740
2741 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2742 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2743}
2744
// A `Refusal` stop must surface as `acp::StopReason::Refusal` and roll the
// thread back: the final markdown is empty, so the last user message and
// everything after it are removed.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message appears in the transcript immediately.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // A partial assistant reply streams in before the refusal.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove the last user
    // message along with everything streamed after it (see empty markdown below).
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2793
// Truncating at the first (and only) message must clear both the transcript
// and the cached token usage, and the thread must accept new messages after.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any model response there is no token usage to report.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // Usage should reflect the reported 32k input + 16k output tokens.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncation at the only message should reset both transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message appears immediately, before the model responds.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2915
// Truncating at the second message must restore the thread — transcript and
// token usage — to exactly its state after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot check reused both before and after truncation: the thread holds
    // exactly the first exchange and its usage numbers.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                    max_output_tokens: None,
                    input_tokens: 32_000,
                    output_tokens: 16_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose id we keep so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });

    // Truncating at the second message should restore the first-exchange state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
3030
// The first successful exchange should trigger title generation on the
// summarization model; later exchanges must leave the title untouched.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for title/summary generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title stays at the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Only the first line of the streamed summary becomes the title, as the
    // assertion below shows ("Goodnight Moon" after the newline is dropped).
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new summarization request was made, and the title is unchanged.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
3076
// A tool call still awaiting permission must be omitted when building the next
// completion request, while a completed tool call (and its result) is included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool use stays pending (permission is never granted in this test).
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::NAME.into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool use runs to completion (tools are auto-allowed except the one above).
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // Compare everything after message index 0 — presumably the system/context
    // prompt; TODO confirm. Note the pending `tool_id_1` use appears nowhere.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
3156
/// End-to-end exercise of `NativeAgentConnection`: creates a session, checks
/// model listing/selection, round-trips a prompt, cancels it, then closes the
/// session and verifies further prompts fail with "Session not found".
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(user_store.clone(), client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let thread_store = cx.new(|cx| ThreadStore::new(cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        thread_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_session(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream back a response through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Explicitly close the session and drop the ACP thread.
    cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
        .await
        .unwrap();
    drop(acp_thread);
    // Prompting a closed session must surface a "Session not found" error.
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
3287
/// Verifies the sequence of ACP events emitted while a tool call streams in:
/// an initial `ToolCall`, an update once input is complete, an `InProgress`
/// status update, and finally a `Completed` update carrying the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Echo something"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: EchoTool::NAME.into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "text": "Hello!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "echo".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The initial tool call carries the partial (empty) input and the tool name
    // in its metadata.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Echo")
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
    );
    // Once input is complete, an update delivers the full raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Echo")
                .kind(acp::ToolKind::Other)
                .raw_input(json!({ "text": "Hello!"}))
        )
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Hello!")
        )
    );
}
3366
/// Ensures a turn that completes successfully emits zero `Retry` events and
/// records the expected conversation transcript.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, collecting any retries along the way.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3409
/// Simulates a mid-stream `ServerOverloaded` error and checks that the thread
/// retries once after the advertised delay, resuming the assistant turn and
/// recording a `[resume]` marker in the transcript.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so the retry request is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3473
/// Checks that when a stream errors after a tool use was emitted, the tool
/// still runs to completion and its result is included in the retry request.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit a complete tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must include the tool's result,
    // proving the tool call was finished despite the stream error.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3553
/// Fails every attempt with a retryable error and verifies the thread retries
/// exactly `MAX_RETRY_ATTEMPTS` times before surfacing the final error.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence the +1).
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    // Attempts are numbered starting at 1.
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
3607
/// Regression test: a streaming tool whose input never completes (the stream
/// errors before `is_input_complete`) must terminate with an error result
/// instead of deadlocking the turn.
#[gpui::test]
async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
    cx: &mut TestAppContext,
) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(StreamingEchoTool::new());
    });

    let _events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the streaming_echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Send a partial tool use (is_input_complete = false), simulating the LLM
    // streaming input for a tool.
    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: "streaming_echo".into(),
        raw_input: r#"{"text": "partial"}"#.into(),
        input: json!({"text": "partial"}),
        is_input_complete: false,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    cx.run_until_parked();

    // Send a stream error WITHOUT ever sending is_input_complete = true.
    // Before the fix, this would deadlock: the tool waits for more partials
    // (or cancellation), run_turn_internal waits for the tool, and the sender
    // keeping the channel open lives inside RunningTurn.
    fake_model.send_last_completion_stream_error(
        LanguageModelCompletionError::UpstreamProviderError {
            message: "Internal server error".to_string(),
            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
            retry_after: None,
        },
    );
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so run_turn_internal retries.
    cx.executor().advance_clock(Duration::from_secs(5));
    cx.run_until_parked();

    // The retry request should contain the streaming tool's error result,
    // proving the tool terminated and its result was forwarded.
    let completion = fake_model
        .pending_completions()
        .pop()
        .expect("No running turn");
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the streaming_echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_use.name,
                        is_error: true,
                        content: "Failed to receive tool input: tool input was not fully received"
                            .into(),
                        output: Some(
                            "Failed to receive tool input: tool input was not fully received"
                                .into()
                        ),
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retry round so the turn completes cleanly.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Thread should not be stuck; the turn should have completed",
        );
    });
}
3714
3715/// Filters out the stop events for asserting against in tests
3716fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3717 result_events
3718 .into_iter()
3719 .filter_map(|event| match event.unwrap() {
3720 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3721 _ => None,
3722 })
3723 .collect()
3724}
3725
/// Bundle of entities produced by `setup` that individual tests destructure
/// as needed.
struct ThreadTest {
    // Either a `FakeLanguageModel` or a real provider model, per `TestModel`.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
3733
/// Which language model a test runs against: a real provider model
/// (requires authentication) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
3738
3739impl TestModel {
3740 fn id(&self) -> LanguageModelId {
3741 match self {
3742 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3743 TestModel::Fake => unreachable!(),
3744 }
3745 }
3746}
3747
/// Builds the shared test fixture: a fake filesystem with a settings file
/// enabling every test tool, an initialized settings/model stack, a test
/// `Project`, and a `Thread` wired to the requested model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file whose test profile enables all tools used by
    // these tests; `watch_settings` below keeps the store in sync with it.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            StreamingEchoTool::NAME: true,
                            StreamingFailingEchoTool::NAME: true,
                            TerminalTool::NAME: true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real models need the production client and provider registry.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(user_store.clone(), client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fakes are constructed directly, real models are
    // looked up in the registry and authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3852
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only hook up env_logger when the developer opted in via RUST_LOG.
    match std::env::var("RUST_LOG") {
        Ok(_) => env_logger::init(),
        Err(_) => {}
    }
}
3860
3861fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3862 let fs = fs.clone();
3863 cx.spawn({
3864 async move |cx| {
3865 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3866 cx.background_executor(),
3867 fs,
3868 paths::settings_file().clone(),
3869 );
3870 let _watcher_task = watcher_task;
3871
3872 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3873 cx.update(|cx| {
3874 SettingsStore::update_global(cx, |settings, cx| {
3875 settings.set_user_settings(&new_settings_content, cx)
3876 })
3877 })
3878 .ok();
3879 }
3880 }
3881 })
3882 .detach();
3883}
3884
3885fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3886 completion
3887 .tools
3888 .iter()
3889 .map(|tool| tool.name.clone())
3890 .collect()
3891}
3892
/// Registers and starts a fake MCP context server named `name` exposing
/// `tools`. Returns a receiver that yields each incoming `CallTool` request
/// together with a oneshot sender the test uses to supply the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    // Never actually executed; the fake transport intercepts I/O.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport answering the MCP handshake and tool requests in-process.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3972
/// Verifies `tokens_before_message`: the first user message has no preceding
/// tokens, and each later message reports the input-token count from the
/// request that preceded it.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
4079
/// Verifies that truncating the thread at a message removes its token
/// bookkeeping: the truncated message's id no longer resolves, while earlier
/// messages are unaffected.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
4150
4151#[gpui::test]
4152async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
4153 init_test(cx);
4154
4155 let fs = FakeFs::new(cx.executor());
4156 fs.insert_tree("/root", json!({})).await;
4157 let project = Project::test(fs, ["/root".as_ref()], cx).await;
4158
4159 // Test 1: Deny rule blocks command
4160 {
4161 let environment = Rc::new(cx.update(|cx| {
4162 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4163 }));
4164
4165 cx.update(|cx| {
4166 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4167 settings.tool_permissions.tools.insert(
4168 TerminalTool::NAME.into(),
4169 agent_settings::ToolRules {
4170 default: Some(settings::ToolPermissionMode::Confirm),
4171 always_allow: vec![],
4172 always_deny: vec![
4173 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
4174 ],
4175 always_confirm: vec![],
4176 invalid_patterns: vec![],
4177 },
4178 );
4179 agent_settings::AgentSettings::override_global(settings, cx);
4180 });
4181
4182 #[allow(clippy::arc_with_non_send_sync)]
4183 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4184 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4185
4186 let task = cx.update(|cx| {
4187 tool.run(
4188 ToolInput::resolved(crate::TerminalToolInput {
4189 command: "rm -rf /".to_string(),
4190 cd: ".".to_string(),
4191 timeout_ms: None,
4192 }),
4193 event_stream,
4194 cx,
4195 )
4196 });
4197
4198 let result = task.await;
4199 assert!(
4200 result.is_err(),
4201 "expected command to be blocked by deny rule"
4202 );
4203 let err_msg = result.unwrap_err().to_lowercase();
4204 assert!(
4205 err_msg.contains("blocked"),
4206 "error should mention the command was blocked"
4207 );
4208 }
4209
4210 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4211 {
4212 let environment = Rc::new(cx.update(|cx| {
4213 FakeThreadEnvironment::default()
4214 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4215 }));
4216
4217 cx.update(|cx| {
4218 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4219 settings.tool_permissions.tools.insert(
4220 TerminalTool::NAME.into(),
4221 agent_settings::ToolRules {
4222 default: Some(settings::ToolPermissionMode::Deny),
4223 always_allow: vec![
4224 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4225 ],
4226 always_deny: vec![],
4227 always_confirm: vec![],
4228 invalid_patterns: vec![],
4229 },
4230 );
4231 agent_settings::AgentSettings::override_global(settings, cx);
4232 });
4233
4234 #[allow(clippy::arc_with_non_send_sync)]
4235 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4236 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4237
4238 let task = cx.update(|cx| {
4239 tool.run(
4240 ToolInput::resolved(crate::TerminalToolInput {
4241 command: "echo hello".to_string(),
4242 cd: ".".to_string(),
4243 timeout_ms: None,
4244 }),
4245 event_stream,
4246 cx,
4247 )
4248 });
4249
4250 let update = rx.expect_update_fields().await;
4251 assert!(
4252 update.content.iter().any(|blocks| {
4253 blocks
4254 .iter()
4255 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4256 }),
4257 "expected terminal content (allow rule should skip confirmation and override default deny)"
4258 );
4259
4260 let result = task.await;
4261 assert!(
4262 result.is_ok(),
4263 "expected command to succeed without confirmation"
4264 );
4265 }
4266
4267 // Test 3: global default: allow does NOT override always_confirm patterns
4268 {
4269 let environment = Rc::new(cx.update(|cx| {
4270 FakeThreadEnvironment::default()
4271 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4272 }));
4273
4274 cx.update(|cx| {
4275 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4276 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4277 settings.tool_permissions.tools.insert(
4278 TerminalTool::NAME.into(),
4279 agent_settings::ToolRules {
4280 default: Some(settings::ToolPermissionMode::Allow),
4281 always_allow: vec![],
4282 always_deny: vec![],
4283 always_confirm: vec![
4284 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4285 ],
4286 invalid_patterns: vec![],
4287 },
4288 );
4289 agent_settings::AgentSettings::override_global(settings, cx);
4290 });
4291
4292 #[allow(clippy::arc_with_non_send_sync)]
4293 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4294 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4295
4296 let _task = cx.update(|cx| {
4297 tool.run(
4298 ToolInput::resolved(crate::TerminalToolInput {
4299 command: "sudo rm file".to_string(),
4300 cd: ".".to_string(),
4301 timeout_ms: None,
4302 }),
4303 event_stream,
4304 cx,
4305 )
4306 });
4307
4308 // With global default: allow, confirm patterns are still respected
4309 // The expect_authorization() call will panic if no authorization is requested,
4310 // which validates that the confirm pattern still triggers confirmation
4311 let _auth = rx.expect_authorization().await;
4312
4313 drop(_task);
4314 }
4315
4316 // Test 4: tool-specific default: deny is respected even with global default: allow
4317 {
4318 let environment = Rc::new(cx.update(|cx| {
4319 FakeThreadEnvironment::default()
4320 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4321 }));
4322
4323 cx.update(|cx| {
4324 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4325 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4326 settings.tool_permissions.tools.insert(
4327 TerminalTool::NAME.into(),
4328 agent_settings::ToolRules {
4329 default: Some(settings::ToolPermissionMode::Deny),
4330 always_allow: vec![],
4331 always_deny: vec![],
4332 always_confirm: vec![],
4333 invalid_patterns: vec![],
4334 },
4335 );
4336 agent_settings::AgentSettings::override_global(settings, cx);
4337 });
4338
4339 #[allow(clippy::arc_with_non_send_sync)]
4340 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4341 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4342
4343 let task = cx.update(|cx| {
4344 tool.run(
4345 ToolInput::resolved(crate::TerminalToolInput {
4346 command: "echo hello".to_string(),
4347 cd: ".".to_string(),
4348 timeout_ms: None,
4349 }),
4350 event_stream,
4351 cx,
4352 )
4353 });
4354
4355 // tool-specific default: deny is respected even with global default: allow
4356 let result = task.await;
4357 assert!(
4358 result.is_err(),
4359 "expected command to be blocked by tool-specific deny default"
4360 );
4361 let err_msg = result.unwrap_err().to_lowercase();
4362 assert!(
4363 err_msg.contains("disabled"),
4364 "error should mention the tool is disabled, got: {err_msg}"
4365 );
4366 }
4367}
4368
4369#[gpui::test]
4370async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4371 init_test(cx);
4372 cx.update(|cx| {
4373 LanguageModelRegistry::test(cx);
4374 });
4375 cx.update(|cx| {
4376 cx.update_flags(true, vec!["subagents".to_string()]);
4377 });
4378
4379 let fs = FakeFs::new(cx.executor());
4380 fs.insert_tree(
4381 "/",
4382 json!({
4383 "a": {
4384 "b.md": "Lorem"
4385 }
4386 }),
4387 )
4388 .await;
4389 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4390 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4391 let agent = NativeAgent::new(
4392 project.clone(),
4393 thread_store.clone(),
4394 Templates::new(),
4395 None,
4396 fs.clone(),
4397 &mut cx.to_async(),
4398 )
4399 .await
4400 .unwrap();
4401 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4402
4403 let acp_thread = cx
4404 .update(|cx| {
4405 connection
4406 .clone()
4407 .new_session(project.clone(), Path::new(""), cx)
4408 })
4409 .await
4410 .unwrap();
4411 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4412 let thread = agent.read_with(cx, |agent, _| {
4413 agent.sessions.get(&session_id).unwrap().thread.clone()
4414 });
4415 let model = Arc::new(FakeLanguageModel::default());
4416
4417 // Ensure empty threads are not saved, even if they get mutated.
4418 thread.update(cx, |thread, cx| {
4419 thread.set_model(model.clone(), cx);
4420 });
4421 cx.run_until_parked();
4422
4423 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4424 cx.run_until_parked();
4425 model.send_last_completion_stream_text_chunk("spawning subagent");
4426 let subagent_tool_input = SpawnAgentToolInput {
4427 label: "label".to_string(),
4428 message: "subagent task prompt".to_string(),
4429 session_id: None,
4430 };
4431 let subagent_tool_use = LanguageModelToolUse {
4432 id: "subagent_1".into(),
4433 name: SpawnAgentTool::NAME.into(),
4434 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4435 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4436 is_input_complete: true,
4437 thought_signature: None,
4438 };
4439 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4440 subagent_tool_use,
4441 ));
4442 model.end_last_completion_stream();
4443
4444 cx.run_until_parked();
4445
4446 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4447 thread
4448 .running_subagent_ids(cx)
4449 .get(0)
4450 .expect("subagent thread should be running")
4451 .clone()
4452 });
4453
4454 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4455 agent
4456 .sessions
4457 .get(&subagent_session_id)
4458 .expect("subagent session should exist")
4459 .acp_thread
4460 .clone()
4461 });
4462
4463 model.send_last_completion_stream_text_chunk("subagent task response");
4464 model.end_last_completion_stream();
4465
4466 cx.run_until_parked();
4467
4468 assert_eq!(
4469 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4470 indoc! {"
4471 ## User
4472
4473 subagent task prompt
4474
4475 ## Assistant
4476
4477 subagent task response
4478
4479 "}
4480 );
4481
4482 model.send_last_completion_stream_text_chunk("Response");
4483 model.end_last_completion_stream();
4484
4485 send.await.unwrap();
4486
4487 assert_eq!(
4488 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4489 indoc! {r#"
4490 ## User
4491
4492 Prompt
4493
4494 ## Assistant
4495
4496 spawning subagent
4497
4498 **Tool Call: label**
4499 Status: Completed
4500
4501 subagent task response
4502
4503 ## Assistant
4504
4505 Response
4506
4507 "#},
4508 );
4509}
4510
4511#[gpui::test]
4512async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
4513 init_test(cx);
4514 cx.update(|cx| {
4515 LanguageModelRegistry::test(cx);
4516 });
4517 cx.update(|cx| {
4518 cx.update_flags(true, vec!["subagents".to_string()]);
4519 });
4520
4521 let fs = FakeFs::new(cx.executor());
4522 fs.insert_tree(
4523 "/",
4524 json!({
4525 "a": {
4526 "b.md": "Lorem"
4527 }
4528 }),
4529 )
4530 .await;
4531 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4532 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4533 let agent = NativeAgent::new(
4534 project.clone(),
4535 thread_store.clone(),
4536 Templates::new(),
4537 None,
4538 fs.clone(),
4539 &mut cx.to_async(),
4540 )
4541 .await
4542 .unwrap();
4543 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4544
4545 let acp_thread = cx
4546 .update(|cx| {
4547 connection
4548 .clone()
4549 .new_session(project.clone(), Path::new(""), cx)
4550 })
4551 .await
4552 .unwrap();
4553 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4554 let thread = agent.read_with(cx, |agent, _| {
4555 agent.sessions.get(&session_id).unwrap().thread.clone()
4556 });
4557 let model = Arc::new(FakeLanguageModel::default());
4558
4559 // Ensure empty threads are not saved, even if they get mutated.
4560 thread.update(cx, |thread, cx| {
4561 thread.set_model(model.clone(), cx);
4562 });
4563 cx.run_until_parked();
4564
4565 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4566 cx.run_until_parked();
4567 model.send_last_completion_stream_text_chunk("spawning subagent");
4568 let subagent_tool_input = SpawnAgentToolInput {
4569 label: "label".to_string(),
4570 message: "subagent task prompt".to_string(),
4571 session_id: None,
4572 };
4573 let subagent_tool_use = LanguageModelToolUse {
4574 id: "subagent_1".into(),
4575 name: SpawnAgentTool::NAME.into(),
4576 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4577 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4578 is_input_complete: true,
4579 thought_signature: None,
4580 };
4581 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4582 subagent_tool_use,
4583 ));
4584 model.end_last_completion_stream();
4585
4586 cx.run_until_parked();
4587
4588 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4589 thread
4590 .running_subagent_ids(cx)
4591 .get(0)
4592 .expect("subagent thread should be running")
4593 .clone()
4594 });
4595
4596 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4597 agent
4598 .sessions
4599 .get(&subagent_session_id)
4600 .expect("subagent session should exist")
4601 .acp_thread
4602 .clone()
4603 });
4604
4605 model.send_last_completion_stream_text_chunk("subagent task response 1");
4606 model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
4607 text: "thinking more about the subagent task".into(),
4608 signature: None,
4609 });
4610 model.send_last_completion_stream_text_chunk("subagent task response 2");
4611 model.end_last_completion_stream();
4612
4613 cx.run_until_parked();
4614
4615 assert_eq!(
4616 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4617 indoc! {"
4618 ## User
4619
4620 subagent task prompt
4621
4622 ## Assistant
4623
4624 subagent task response 1
4625
4626 <thinking>
4627 thinking more about the subagent task
4628 </thinking>
4629
4630 subagent task response 2
4631
4632 "}
4633 );
4634
4635 model.send_last_completion_stream_text_chunk("Response");
4636 model.end_last_completion_stream();
4637
4638 send.await.unwrap();
4639
4640 assert_eq!(
4641 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4642 indoc! {r#"
4643 ## User
4644
4645 Prompt
4646
4647 ## Assistant
4648
4649 spawning subagent
4650
4651 **Tool Call: label**
4652 Status: Completed
4653
4654 subagent task response 1
4655
4656 subagent task response 2
4657
4658 ## Assistant
4659
4660 Response
4661
4662 "#},
4663 );
4664}
4665
4666#[gpui::test]
4667async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4668 init_test(cx);
4669 cx.update(|cx| {
4670 LanguageModelRegistry::test(cx);
4671 });
4672 cx.update(|cx| {
4673 cx.update_flags(true, vec!["subagents".to_string()]);
4674 });
4675
4676 let fs = FakeFs::new(cx.executor());
4677 fs.insert_tree(
4678 "/",
4679 json!({
4680 "a": {
4681 "b.md": "Lorem"
4682 }
4683 }),
4684 )
4685 .await;
4686 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4687 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4688 let agent = NativeAgent::new(
4689 project.clone(),
4690 thread_store.clone(),
4691 Templates::new(),
4692 None,
4693 fs.clone(),
4694 &mut cx.to_async(),
4695 )
4696 .await
4697 .unwrap();
4698 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4699
4700 let acp_thread = cx
4701 .update(|cx| {
4702 connection
4703 .clone()
4704 .new_session(project.clone(), Path::new(""), cx)
4705 })
4706 .await
4707 .unwrap();
4708 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4709 let thread = agent.read_with(cx, |agent, _| {
4710 agent.sessions.get(&session_id).unwrap().thread.clone()
4711 });
4712 let model = Arc::new(FakeLanguageModel::default());
4713
4714 // Ensure empty threads are not saved, even if they get mutated.
4715 thread.update(cx, |thread, cx| {
4716 thread.set_model(model.clone(), cx);
4717 });
4718 cx.run_until_parked();
4719
4720 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4721 cx.run_until_parked();
4722 model.send_last_completion_stream_text_chunk("spawning subagent");
4723 let subagent_tool_input = SpawnAgentToolInput {
4724 label: "label".to_string(),
4725 message: "subagent task prompt".to_string(),
4726 session_id: None,
4727 };
4728 let subagent_tool_use = LanguageModelToolUse {
4729 id: "subagent_1".into(),
4730 name: SpawnAgentTool::NAME.into(),
4731 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4732 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4733 is_input_complete: true,
4734 thought_signature: None,
4735 };
4736 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4737 subagent_tool_use,
4738 ));
4739 model.end_last_completion_stream();
4740
4741 cx.run_until_parked();
4742
4743 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4744 thread
4745 .running_subagent_ids(cx)
4746 .get(0)
4747 .expect("subagent thread should be running")
4748 .clone()
4749 });
4750 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4751 agent
4752 .sessions
4753 .get(&subagent_session_id)
4754 .expect("subagent session should exist")
4755 .acp_thread
4756 .clone()
4757 });
4758
4759 // model.send_last_completion_stream_text_chunk("subagent task response");
4760 // model.end_last_completion_stream();
4761
4762 // cx.run_until_parked();
4763
4764 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4765
4766 cx.run_until_parked();
4767
4768 send.await.unwrap();
4769
4770 acp_thread.read_with(cx, |thread, cx| {
4771 assert_eq!(thread.status(), ThreadStatus::Idle);
4772 assert_eq!(
4773 thread.to_markdown(cx),
4774 indoc! {"
4775 ## User
4776
4777 Prompt
4778
4779 ## Assistant
4780
4781 spawning subagent
4782
4783 **Tool Call: label**
4784 Status: Canceled
4785
4786 "}
4787 );
4788 });
4789 subagent_acp_thread.read_with(cx, |thread, cx| {
4790 assert_eq!(thread.status(), ThreadStatus::Idle);
4791 assert_eq!(
4792 thread.to_markdown(cx),
4793 indoc! {"
4794 ## User
4795
4796 subagent task prompt
4797
4798 "}
4799 );
4800 });
4801}
4802
4803#[gpui::test]
4804async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4805 init_test(cx);
4806 cx.update(|cx| {
4807 LanguageModelRegistry::test(cx);
4808 });
4809 cx.update(|cx| {
4810 cx.update_flags(true, vec!["subagents".to_string()]);
4811 });
4812
4813 let fs = FakeFs::new(cx.executor());
4814 fs.insert_tree(
4815 "/",
4816 json!({
4817 "a": {
4818 "b.md": "Lorem"
4819 }
4820 }),
4821 )
4822 .await;
4823 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4824 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4825 let agent = NativeAgent::new(
4826 project.clone(),
4827 thread_store.clone(),
4828 Templates::new(),
4829 None,
4830 fs.clone(),
4831 &mut cx.to_async(),
4832 )
4833 .await
4834 .unwrap();
4835 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4836
4837 let acp_thread = cx
4838 .update(|cx| {
4839 connection
4840 .clone()
4841 .new_session(project.clone(), Path::new(""), cx)
4842 })
4843 .await
4844 .unwrap();
4845 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4846 let thread = agent.read_with(cx, |agent, _| {
4847 agent.sessions.get(&session_id).unwrap().thread.clone()
4848 });
4849 let model = Arc::new(FakeLanguageModel::default());
4850
4851 thread.update(cx, |thread, cx| {
4852 thread.set_model(model.clone(), cx);
4853 });
4854 cx.run_until_parked();
4855
4856 // === First turn: create subagent ===
4857 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4858 cx.run_until_parked();
4859 model.send_last_completion_stream_text_chunk("spawning subagent");
4860 let subagent_tool_input = SpawnAgentToolInput {
4861 label: "initial task".to_string(),
4862 message: "do the first task".to_string(),
4863 session_id: None,
4864 };
4865 let subagent_tool_use = LanguageModelToolUse {
4866 id: "subagent_1".into(),
4867 name: SpawnAgentTool::NAME.into(),
4868 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4869 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4870 is_input_complete: true,
4871 thought_signature: None,
4872 };
4873 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4874 subagent_tool_use,
4875 ));
4876 model.end_last_completion_stream();
4877
4878 cx.run_until_parked();
4879
4880 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4881 thread
4882 .running_subagent_ids(cx)
4883 .get(0)
4884 .expect("subagent thread should be running")
4885 .clone()
4886 });
4887
4888 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4889 agent
4890 .sessions
4891 .get(&subagent_session_id)
4892 .expect("subagent session should exist")
4893 .acp_thread
4894 .clone()
4895 });
4896
4897 // Subagent responds
4898 model.send_last_completion_stream_text_chunk("first task response");
4899 model.end_last_completion_stream();
4900
4901 cx.run_until_parked();
4902
4903 // Parent model responds to complete first turn
4904 model.send_last_completion_stream_text_chunk("First response");
4905 model.end_last_completion_stream();
4906
4907 send.await.unwrap();
4908
4909 // Verify subagent is no longer running
4910 thread.read_with(cx, |thread, cx| {
4911 assert!(
4912 thread.running_subagent_ids(cx).is_empty(),
4913 "subagent should not be running after completion"
4914 );
4915 });
4916
4917 // === Second turn: resume subagent with session_id ===
4918 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4919 cx.run_until_parked();
4920 model.send_last_completion_stream_text_chunk("resuming subagent");
4921 let resume_tool_input = SpawnAgentToolInput {
4922 label: "follow-up task".to_string(),
4923 message: "do the follow-up task".to_string(),
4924 session_id: Some(subagent_session_id.clone()),
4925 };
4926 let resume_tool_use = LanguageModelToolUse {
4927 id: "subagent_2".into(),
4928 name: SpawnAgentTool::NAME.into(),
4929 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4930 input: serde_json::to_value(&resume_tool_input).unwrap(),
4931 is_input_complete: true,
4932 thought_signature: None,
4933 };
4934 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4935 model.end_last_completion_stream();
4936
4937 cx.run_until_parked();
4938
4939 // Subagent should be running again with the same session
4940 thread.read_with(cx, |thread, cx| {
4941 let running = thread.running_subagent_ids(cx);
4942 assert_eq!(running.len(), 1, "subagent should be running");
4943 assert_eq!(running[0], subagent_session_id, "should be same session");
4944 });
4945
4946 // Subagent responds to follow-up
4947 model.send_last_completion_stream_text_chunk("follow-up task response");
4948 model.end_last_completion_stream();
4949
4950 cx.run_until_parked();
4951
4952 // Parent model responds to complete second turn
4953 model.send_last_completion_stream_text_chunk("Second response");
4954 model.end_last_completion_stream();
4955
4956 send2.await.unwrap();
4957
4958 // Verify subagent is no longer running
4959 thread.read_with(cx, |thread, cx| {
4960 assert!(
4961 thread.running_subagent_ids(cx).is_empty(),
4962 "subagent should not be running after resume completion"
4963 );
4964 });
4965
4966 // Verify the subagent's acp thread has both conversation turns
4967 assert_eq!(
4968 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4969 indoc! {"
4970 ## User
4971
4972 do the first task
4973
4974 ## Assistant
4975
4976 first task response
4977
4978 ## User
4979
4980 do the follow-up task
4981
4982 ## Assistant
4983
4984 follow-up task response
4985
4986 "}
4987 );
4988}
4989
4990#[gpui::test]
4991async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4992 init_test(cx);
4993
4994 cx.update(|cx| {
4995 cx.update_flags(true, vec!["subagents".to_string()]);
4996 });
4997
4998 let fs = FakeFs::new(cx.executor());
4999 fs.insert_tree(path!("/test"), json!({})).await;
5000 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5001 let project_context = cx.new(|_cx| ProjectContext::default());
5002 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5003 let context_server_registry =
5004 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5005 let model = Arc::new(FakeLanguageModel::default());
5006
5007 let environment = Rc::new(cx.update(|cx| {
5008 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
5009 }));
5010
5011 let thread = cx.new(|cx| {
5012 let mut thread = Thread::new(
5013 project.clone(),
5014 project_context,
5015 context_server_registry,
5016 Templates::new(),
5017 Some(model),
5018 cx,
5019 );
5020 thread.add_default_tools(environment, cx);
5021 thread
5022 });
5023
5024 thread.read_with(cx, |thread, _| {
5025 assert!(
5026 thread.has_registered_tool(SpawnAgentTool::NAME),
5027 "subagent tool should be present when feature flag is enabled"
5028 );
5029 });
5030}
5031
5032#[gpui::test]
5033async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
5034 init_test(cx);
5035
5036 cx.update(|cx| {
5037 cx.update_flags(true, vec!["subagents".to_string()]);
5038 });
5039
5040 let fs = FakeFs::new(cx.executor());
5041 fs.insert_tree(path!("/test"), json!({})).await;
5042 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5043 let project_context = cx.new(|_cx| ProjectContext::default());
5044 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5045 let context_server_registry =
5046 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5047 let model = Arc::new(FakeLanguageModel::default());
5048
5049 let parent_thread = cx.new(|cx| {
5050 Thread::new(
5051 project.clone(),
5052 project_context,
5053 context_server_registry,
5054 Templates::new(),
5055 Some(model.clone()),
5056 cx,
5057 )
5058 });
5059
5060 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
5061 subagent_thread.read_with(cx, |subagent_thread, cx| {
5062 assert!(subagent_thread.is_subagent());
5063 assert_eq!(subagent_thread.depth(), 1);
5064 assert_eq!(
5065 subagent_thread.model().map(|model| model.id()),
5066 Some(model.id())
5067 );
5068 assert_eq!(
5069 subagent_thread.parent_thread_id(),
5070 Some(parent_thread.read(cx).id().clone())
5071 );
5072 });
5073}
5074
5075#[gpui::test]
5076async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
5077 init_test(cx);
5078
5079 cx.update(|cx| {
5080 cx.update_flags(true, vec!["subagents".to_string()]);
5081 });
5082
5083 let fs = FakeFs::new(cx.executor());
5084 fs.insert_tree(path!("/test"), json!({})).await;
5085 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5086 let project_context = cx.new(|_cx| ProjectContext::default());
5087 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5088 let context_server_registry =
5089 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5090 let model = Arc::new(FakeLanguageModel::default());
5091 let environment = Rc::new(cx.update(|cx| {
5092 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
5093 }));
5094
5095 let deep_parent_thread = cx.new(|cx| {
5096 let mut thread = Thread::new(
5097 project.clone(),
5098 project_context,
5099 context_server_registry,
5100 Templates::new(),
5101 Some(model.clone()),
5102 cx,
5103 );
5104 thread.set_subagent_context(SubagentContext {
5105 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
5106 depth: MAX_SUBAGENT_DEPTH - 1,
5107 });
5108 thread
5109 });
5110 let deep_subagent_thread = cx.new(|cx| {
5111 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
5112 thread.add_default_tools(environment, cx);
5113 thread
5114 });
5115
5116 deep_subagent_thread.read_with(cx, |thread, _| {
5117 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
5118 assert!(
5119 !thread.has_registered_tool(SpawnAgentTool::NAME),
5120 "subagent tool should not be present at max depth"
5121 );
5122 });
5123}
5124
5125#[gpui::test]
5126async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
5127 init_test(cx);
5128
5129 cx.update(|cx| {
5130 cx.update_flags(true, vec!["subagents".to_string()]);
5131 });
5132
5133 let fs = FakeFs::new(cx.executor());
5134 fs.insert_tree(path!("/test"), json!({})).await;
5135 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5136 let project_context = cx.new(|_cx| ProjectContext::default());
5137 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5138 let context_server_registry =
5139 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5140 let model = Arc::new(FakeLanguageModel::default());
5141
5142 let parent = cx.new(|cx| {
5143 Thread::new(
5144 project.clone(),
5145 project_context.clone(),
5146 context_server_registry.clone(),
5147 Templates::new(),
5148 Some(model.clone()),
5149 cx,
5150 )
5151 });
5152
5153 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
5154
5155 parent.update(cx, |thread, _cx| {
5156 thread.register_running_subagent(subagent.downgrade());
5157 });
5158
5159 subagent
5160 .update(cx, |thread, cx| {
5161 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
5162 })
5163 .unwrap();
5164 cx.run_until_parked();
5165
5166 subagent.read_with(cx, |thread, _| {
5167 assert!(!thread.is_turn_complete(), "subagent should be running");
5168 });
5169
5170 parent.update(cx, |thread, cx| {
5171 thread.cancel(cx).detach();
5172 });
5173
5174 subagent.read_with(cx, |thread, _| {
5175 assert!(
5176 thread.is_turn_complete(),
5177 "subagent should be cancelled when parent cancels"
5178 );
5179 });
5180}
5181
5182#[gpui::test]
5183async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
5184 init_test(cx);
5185 cx.update(|cx| {
5186 LanguageModelRegistry::test(cx);
5187 });
5188 cx.update(|cx| {
5189 cx.update_flags(true, vec!["subagents".to_string()]);
5190 });
5191
5192 let fs = FakeFs::new(cx.executor());
5193 fs.insert_tree(
5194 "/",
5195 json!({
5196 "a": {
5197 "b.md": "Lorem"
5198 }
5199 }),
5200 )
5201 .await;
5202 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5203 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5204 let agent = NativeAgent::new(
5205 project.clone(),
5206 thread_store.clone(),
5207 Templates::new(),
5208 None,
5209 fs.clone(),
5210 &mut cx.to_async(),
5211 )
5212 .await
5213 .unwrap();
5214 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5215
5216 let acp_thread = cx
5217 .update(|cx| {
5218 connection
5219 .clone()
5220 .new_session(project.clone(), Path::new(""), cx)
5221 })
5222 .await
5223 .unwrap();
5224 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5225 let thread = agent.read_with(cx, |agent, _| {
5226 agent.sessions.get(&session_id).unwrap().thread.clone()
5227 });
5228 let model = Arc::new(FakeLanguageModel::default());
5229
5230 thread.update(cx, |thread, cx| {
5231 thread.set_model(model.clone(), cx);
5232 });
5233 cx.run_until_parked();
5234
5235 // Start the parent turn
5236 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5237 cx.run_until_parked();
5238 model.send_last_completion_stream_text_chunk("spawning subagent");
5239 let subagent_tool_input = SpawnAgentToolInput {
5240 label: "label".to_string(),
5241 message: "subagent task prompt".to_string(),
5242 session_id: None,
5243 };
5244 let subagent_tool_use = LanguageModelToolUse {
5245 id: "subagent_1".into(),
5246 name: SpawnAgentTool::NAME.into(),
5247 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5248 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5249 is_input_complete: true,
5250 thought_signature: None,
5251 };
5252 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5253 subagent_tool_use,
5254 ));
5255 model.end_last_completion_stream();
5256
5257 cx.run_until_parked();
5258
5259 // Verify subagent is running
5260 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5261 thread
5262 .running_subagent_ids(cx)
5263 .get(0)
5264 .expect("subagent thread should be running")
5265 .clone()
5266 });
5267
5268 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
5269 model.send_last_completion_stream_text_chunk("partial work");
5270 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5271 TokenUsage {
5272 input_tokens: 850_000,
5273 output_tokens: 0,
5274 cache_creation_input_tokens: 0,
5275 cache_read_input_tokens: 0,
5276 },
5277 ));
5278
5279 cx.run_until_parked();
5280
5281 // The subagent should no longer be running
5282 thread.read_with(cx, |thread, cx| {
5283 assert!(
5284 thread.running_subagent_ids(cx).is_empty(),
5285 "subagent should be stopped after context window warning"
5286 );
5287 });
5288
5289 // The parent model should get a new completion request to respond to the tool error
5290 model.send_last_completion_stream_text_chunk("Response after warning");
5291 model.end_last_completion_stream();
5292
5293 send.await.unwrap();
5294
5295 // Verify the parent thread shows the warning error in the tool call
5296 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5297 assert!(
5298 markdown.contains("nearing the end of its context window"),
5299 "tool output should contain context window warning message, got:\n{markdown}"
5300 );
5301 assert!(
5302 markdown.contains("Status: Failed"),
5303 "tool call should have Failed status, got:\n{markdown}"
5304 );
5305
5306 // Verify the subagent session still exists (can be resumed)
5307 agent.read_with(cx, |agent, _cx| {
5308 assert!(
5309 agent.sessions.contains_key(&subagent_session_id),
5310 "subagent session should still exist for potential resume"
5311 );
5312 });
5313}
5314
/// Verifies that resuming a subagent whose token usage was already at the
/// Warning level does NOT re-trigger the context-window warning: only a
/// Normal→Warning transition stops a subagent mid-turn.
#[gpui::test]
async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
    init_test(cx);
    cx.update(|cx| {
        LanguageModelRegistry::test(cx);
    });
    // Subagent support is gated behind the "subagents" feature flag.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new(""), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    // Grab the underlying agent thread so we can inspect subagent state directly.
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    // The same fake model serves both the parent turn and the subagent;
    // `send_last_completion_stream_*` always targets the most recent request.
    let model = Arc::new(FakeLanguageModel::default());

    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });
    cx.run_until_parked();

    // === First turn: create subagent, trigger context window warning ===
    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("spawning subagent");
    // No session_id: spawn a brand-new subagent session.
    let subagent_tool_input = SpawnAgentToolInput {
        label: "initial task".to_string(),
        message: "do the first task".to_string(),
        session_id: None,
    };
    let subagent_tool_use = LanguageModelToolUse {
        id: "subagent_1".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
        input: serde_json::to_value(&subagent_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        subagent_tool_use,
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Remember the subagent's session id so the second turn can resume it.
    let subagent_session_id = thread.read_with(cx, |thread, cx| {
        thread
            .running_subagent_ids(cx)
            .get(0)
            .expect("subagent thread should be running")
            .clone()
    });

    // Subagent sends a usage update that crosses the warning threshold.
    // This triggers Normal→Warning, stopping the subagent.
    model.send_last_completion_stream_text_chunk("partial work");
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        TokenUsage {
            input_tokens: 850_000,
            output_tokens: 0,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));

    cx.run_until_parked();

    // Verify the first turn was stopped with a context window warning
    thread.read_with(cx, |thread, cx| {
        assert!(
            thread.running_subagent_ids(cx).is_empty(),
            "subagent should be stopped after context window warning"
        );
    });

    // Parent model responds to complete first turn
    model.send_last_completion_stream_text_chunk("First response");
    model.end_last_completion_stream();

    send.await.unwrap();

    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("nearing the end of its context window"),
        "first turn should have context window warning, got:\n{markdown}"
    );

    // === Second turn: resume the same subagent (now at Warning level) ===
    let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("resuming subagent");
    // Passing the previous session_id resumes the existing subagent session.
    let resume_tool_input = SpawnAgentToolInput {
        label: "follow-up task".to_string(),
        message: "do the follow-up task".to_string(),
        session_id: Some(subagent_session_id.clone()),
    };
    let resume_tool_use = LanguageModelToolUse {
        id: "subagent_2".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
        input: serde_json::to_value(&resume_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Subagent responds with tokens still at warning level (no worse).
    // Since ratio_before_prompt was already Warning, this should NOT
    // trigger the context window warning again.
    model.send_last_completion_stream_text_chunk("follow-up task response");
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        TokenUsage {
            input_tokens: 870_000,
            output_tokens: 0,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Parent model responds to complete second turn
    model.send_last_completion_stream_text_chunk("Second response");
    model.end_last_completion_stream();

    send2.await.unwrap();

    // The resumed subagent should have completed normally since the ratio
    // didn't transition (it was Warning before and stayed at Warning)
    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("follow-up task response"),
        "resumed subagent should complete normally when already at warning, got:\n{markdown}"
    );
    // The second tool call should NOT have a context window warning.
    // Only the markdown from the second tool call onward is scanned, so the
    // first turn's (expected) warning doesn't cause a false failure.
    let second_tool_pos = markdown
        .find("follow-up task")
        .expect("should find follow-up tool call");
    let after_second_tool = &markdown[second_tool_pos..];
    assert!(
        !after_second_tool.contains("nearing the end of its context window"),
        "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
    );
}
5495
/// Verifies that a non-retryable model error inside a subagent stops the
/// subagent and is surfaced on the parent thread as a failed tool call.
#[gpui::test]
async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
    init_test(cx);
    cx.update(|cx| {
        LanguageModelRegistry::test(cx);
    });
    // Subagent support is gated behind the "subagents" feature flag.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new(""), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    // Grab the underlying agent thread so we can inspect subagent state directly.
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    // The same fake model serves both the parent turn and the subagent.
    let model = Arc::new(FakeLanguageModel::default());

    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });
    cx.run_until_parked();

    // Start the parent turn
    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("spawning subagent");
    // No session_id: spawn a brand-new subagent session.
    let subagent_tool_input = SpawnAgentToolInput {
        label: "label".to_string(),
        message: "subagent task prompt".to_string(),
        session_id: None,
    };
    let subagent_tool_use = LanguageModelToolUse {
        id: "subagent_1".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
        input: serde_json::to_value(&subagent_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        subagent_tool_use,
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Verify subagent is running
    thread.read_with(cx, |thread, cx| {
        assert!(
            !thread.running_subagent_ids(cx).is_empty(),
            "subagent should be running"
        );
    });

    // The subagent's model returns a non-retryable error
    model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
        tokens: None,
    });

    cx.run_until_parked();

    // The subagent should no longer be running
    thread.read_with(cx, |thread, cx| {
        assert!(
            thread.running_subagent_ids(cx).is_empty(),
            "subagent should not be running after error"
        );
    });

    // The parent model should get a new completion request to respond to the tool error
    model.send_last_completion_stream_text_chunk("Response after error");
    model.end_last_completion_stream();

    send.await.unwrap();

    // Verify the parent thread shows the error in the tool call
    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("Status: Failed"),
        "tool call should have Failed status after model error, got:\n{markdown}"
    );
}
5609
5610#[gpui::test]
5611async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5612 init_test(cx);
5613
5614 let fs = FakeFs::new(cx.executor());
5615 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5616 .await;
5617 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5618
5619 cx.update(|cx| {
5620 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5621 settings.tool_permissions.tools.insert(
5622 EditFileTool::NAME.into(),
5623 agent_settings::ToolRules {
5624 default: Some(settings::ToolPermissionMode::Allow),
5625 always_allow: vec![],
5626 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5627 always_confirm: vec![],
5628 invalid_patterns: vec![],
5629 },
5630 );
5631 agent_settings::AgentSettings::override_global(settings, cx);
5632 });
5633
5634 let context_server_registry =
5635 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5636 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5637 let templates = crate::Templates::new();
5638 let thread = cx.new(|cx| {
5639 crate::Thread::new(
5640 project.clone(),
5641 cx.new(|_cx| prompt_store::ProjectContext::default()),
5642 context_server_registry,
5643 templates.clone(),
5644 None,
5645 cx,
5646 )
5647 });
5648
5649 #[allow(clippy::arc_with_non_send_sync)]
5650 let tool = Arc::new(crate::EditFileTool::new(
5651 project.clone(),
5652 thread.downgrade(),
5653 language_registry,
5654 templates,
5655 ));
5656 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5657
5658 let task = cx.update(|cx| {
5659 tool.run(
5660 ToolInput::resolved(crate::EditFileToolInput {
5661 display_description: "Edit sensitive file".to_string(),
5662 path: "root/sensitive_config.txt".into(),
5663 mode: crate::EditFileMode::Edit,
5664 }),
5665 event_stream,
5666 cx,
5667 )
5668 });
5669
5670 let result = task.await;
5671 assert!(result.is_err(), "expected edit to be blocked");
5672 assert!(
5673 result.unwrap_err().to_string().contains("blocked"),
5674 "error should mention the edit was blocked"
5675 );
5676}
5677
5678#[gpui::test]
5679async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5680 init_test(cx);
5681
5682 let fs = FakeFs::new(cx.executor());
5683 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5684 .await;
5685 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5686
5687 cx.update(|cx| {
5688 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5689 settings.tool_permissions.tools.insert(
5690 DeletePathTool::NAME.into(),
5691 agent_settings::ToolRules {
5692 default: Some(settings::ToolPermissionMode::Allow),
5693 always_allow: vec![],
5694 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5695 always_confirm: vec![],
5696 invalid_patterns: vec![],
5697 },
5698 );
5699 agent_settings::AgentSettings::override_global(settings, cx);
5700 });
5701
5702 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5703
5704 #[allow(clippy::arc_with_non_send_sync)]
5705 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5706 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5707
5708 let task = cx.update(|cx| {
5709 tool.run(
5710 ToolInput::resolved(crate::DeletePathToolInput {
5711 path: "root/important_data.txt".to_string(),
5712 }),
5713 event_stream,
5714 cx,
5715 )
5716 });
5717
5718 let result = task.await;
5719 assert!(result.is_err(), "expected deletion to be blocked");
5720 assert!(
5721 result.unwrap_err().contains("blocked"),
5722 "error should mention the deletion was blocked"
5723 );
5724}
5725
5726#[gpui::test]
5727async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5728 init_test(cx);
5729
5730 let fs = FakeFs::new(cx.executor());
5731 fs.insert_tree(
5732 "/root",
5733 json!({
5734 "safe.txt": "content",
5735 "protected": {}
5736 }),
5737 )
5738 .await;
5739 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5740
5741 cx.update(|cx| {
5742 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5743 settings.tool_permissions.tools.insert(
5744 MovePathTool::NAME.into(),
5745 agent_settings::ToolRules {
5746 default: Some(settings::ToolPermissionMode::Allow),
5747 always_allow: vec![],
5748 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5749 always_confirm: vec![],
5750 invalid_patterns: vec![],
5751 },
5752 );
5753 agent_settings::AgentSettings::override_global(settings, cx);
5754 });
5755
5756 #[allow(clippy::arc_with_non_send_sync)]
5757 let tool = Arc::new(crate::MovePathTool::new(project));
5758 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5759
5760 let task = cx.update(|cx| {
5761 tool.run(
5762 ToolInput::resolved(crate::MovePathToolInput {
5763 source_path: "root/safe.txt".to_string(),
5764 destination_path: "root/protected/safe.txt".to_string(),
5765 }),
5766 event_stream,
5767 cx,
5768 )
5769 });
5770
5771 let result = task.await;
5772 assert!(
5773 result.is_err(),
5774 "expected move to be blocked due to destination path"
5775 );
5776 assert!(
5777 result.unwrap_err().contains("blocked"),
5778 "error should mention the move was blocked"
5779 );
5780}
5781
5782#[gpui::test]
5783async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5784 init_test(cx);
5785
5786 let fs = FakeFs::new(cx.executor());
5787 fs.insert_tree(
5788 "/root",
5789 json!({
5790 "secret.txt": "secret content",
5791 "public": {}
5792 }),
5793 )
5794 .await;
5795 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5796
5797 cx.update(|cx| {
5798 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5799 settings.tool_permissions.tools.insert(
5800 MovePathTool::NAME.into(),
5801 agent_settings::ToolRules {
5802 default: Some(settings::ToolPermissionMode::Allow),
5803 always_allow: vec![],
5804 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5805 always_confirm: vec![],
5806 invalid_patterns: vec![],
5807 },
5808 );
5809 agent_settings::AgentSettings::override_global(settings, cx);
5810 });
5811
5812 #[allow(clippy::arc_with_non_send_sync)]
5813 let tool = Arc::new(crate::MovePathTool::new(project));
5814 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5815
5816 let task = cx.update(|cx| {
5817 tool.run(
5818 ToolInput::resolved(crate::MovePathToolInput {
5819 source_path: "root/secret.txt".to_string(),
5820 destination_path: "root/public/not_secret.txt".to_string(),
5821 }),
5822 event_stream,
5823 cx,
5824 )
5825 });
5826
5827 let result = task.await;
5828 assert!(
5829 result.is_err(),
5830 "expected move to be blocked due to source path"
5831 );
5832 assert!(
5833 result.unwrap_err().contains("blocked"),
5834 "error should mention the move was blocked"
5835 );
5836}
5837
5838#[gpui::test]
5839async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5840 init_test(cx);
5841
5842 let fs = FakeFs::new(cx.executor());
5843 fs.insert_tree(
5844 "/root",
5845 json!({
5846 "confidential.txt": "confidential data",
5847 "dest": {}
5848 }),
5849 )
5850 .await;
5851 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5852
5853 cx.update(|cx| {
5854 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5855 settings.tool_permissions.tools.insert(
5856 CopyPathTool::NAME.into(),
5857 agent_settings::ToolRules {
5858 default: Some(settings::ToolPermissionMode::Allow),
5859 always_allow: vec![],
5860 always_deny: vec![
5861 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5862 ],
5863 always_confirm: vec![],
5864 invalid_patterns: vec![],
5865 },
5866 );
5867 agent_settings::AgentSettings::override_global(settings, cx);
5868 });
5869
5870 #[allow(clippy::arc_with_non_send_sync)]
5871 let tool = Arc::new(crate::CopyPathTool::new(project));
5872 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5873
5874 let task = cx.update(|cx| {
5875 tool.run(
5876 ToolInput::resolved(crate::CopyPathToolInput {
5877 source_path: "root/confidential.txt".to_string(),
5878 destination_path: "root/dest/copy.txt".to_string(),
5879 }),
5880 event_stream,
5881 cx,
5882 )
5883 });
5884
5885 let result = task.await;
5886 assert!(result.is_err(), "expected copy to be blocked");
5887 assert!(
5888 result.unwrap_err().contains("blocked"),
5889 "error should mention the copy was blocked"
5890 );
5891}
5892
5893#[gpui::test]
5894async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5895 init_test(cx);
5896
5897 let fs = FakeFs::new(cx.executor());
5898 fs.insert_tree(
5899 "/root",
5900 json!({
5901 "normal.txt": "normal content",
5902 "readonly": {
5903 "config.txt": "readonly content"
5904 }
5905 }),
5906 )
5907 .await;
5908 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5909
5910 cx.update(|cx| {
5911 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5912 settings.tool_permissions.tools.insert(
5913 SaveFileTool::NAME.into(),
5914 agent_settings::ToolRules {
5915 default: Some(settings::ToolPermissionMode::Allow),
5916 always_allow: vec![],
5917 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5918 always_confirm: vec![],
5919 invalid_patterns: vec![],
5920 },
5921 );
5922 agent_settings::AgentSettings::override_global(settings, cx);
5923 });
5924
5925 #[allow(clippy::arc_with_non_send_sync)]
5926 let tool = Arc::new(crate::SaveFileTool::new(project));
5927 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5928
5929 let task = cx.update(|cx| {
5930 tool.run(
5931 ToolInput::resolved(crate::SaveFileToolInput {
5932 paths: vec![
5933 std::path::PathBuf::from("root/normal.txt"),
5934 std::path::PathBuf::from("root/readonly/config.txt"),
5935 ],
5936 }),
5937 event_stream,
5938 cx,
5939 )
5940 });
5941
5942 let result = task.await;
5943 assert!(
5944 result.is_err(),
5945 "expected save to be blocked due to denied path"
5946 );
5947 assert!(
5948 result.unwrap_err().contains("blocked"),
5949 "error should mention the save was blocked"
5950 );
5951}
5952
5953#[gpui::test]
5954async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5955 init_test(cx);
5956
5957 let fs = FakeFs::new(cx.executor());
5958 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5959 .await;
5960 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5961
5962 cx.update(|cx| {
5963 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5964 settings.tool_permissions.tools.insert(
5965 SaveFileTool::NAME.into(),
5966 agent_settings::ToolRules {
5967 default: Some(settings::ToolPermissionMode::Allow),
5968 always_allow: vec![],
5969 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5970 always_confirm: vec![],
5971 invalid_patterns: vec![],
5972 },
5973 );
5974 agent_settings::AgentSettings::override_global(settings, cx);
5975 });
5976
5977 #[allow(clippy::arc_with_non_send_sync)]
5978 let tool = Arc::new(crate::SaveFileTool::new(project));
5979 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5980
5981 let task = cx.update(|cx| {
5982 tool.run(
5983 ToolInput::resolved(crate::SaveFileToolInput {
5984 paths: vec![std::path::PathBuf::from("root/config.secret")],
5985 }),
5986 event_stream,
5987 cx,
5988 )
5989 });
5990
5991 let result = task.await;
5992 assert!(result.is_err(), "expected save to be blocked");
5993 assert!(
5994 result.unwrap_err().contains("blocked"),
5995 "error should mention the save was blocked"
5996 );
5997}
5998
5999#[gpui::test]
6000async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
6001 init_test(cx);
6002
6003 cx.update(|cx| {
6004 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6005 settings.tool_permissions.tools.insert(
6006 WebSearchTool::NAME.into(),
6007 agent_settings::ToolRules {
6008 default: Some(settings::ToolPermissionMode::Allow),
6009 always_allow: vec![],
6010 always_deny: vec![
6011 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
6012 ],
6013 always_confirm: vec![],
6014 invalid_patterns: vec![],
6015 },
6016 );
6017 agent_settings::AgentSettings::override_global(settings, cx);
6018 });
6019
6020 #[allow(clippy::arc_with_non_send_sync)]
6021 let tool = Arc::new(crate::WebSearchTool);
6022 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6023
6024 let input: crate::WebSearchToolInput =
6025 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
6026
6027 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6028
6029 let result = task.await;
6030 assert!(result.is_err(), "expected search to be blocked");
6031 match result.unwrap_err() {
6032 crate::WebSearchToolOutput::Error { error } => {
6033 assert!(
6034 error.contains("blocked"),
6035 "error should mention the search was blocked"
6036 );
6037 }
6038 other => panic!("expected Error variant, got: {other:?}"),
6039 }
6040}
6041
6042#[gpui::test]
6043async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6044 init_test(cx);
6045
6046 let fs = FakeFs::new(cx.executor());
6047 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
6048 .await;
6049 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6050
6051 cx.update(|cx| {
6052 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6053 settings.tool_permissions.tools.insert(
6054 EditFileTool::NAME.into(),
6055 agent_settings::ToolRules {
6056 default: Some(settings::ToolPermissionMode::Confirm),
6057 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
6058 always_deny: vec![],
6059 always_confirm: vec![],
6060 invalid_patterns: vec![],
6061 },
6062 );
6063 agent_settings::AgentSettings::override_global(settings, cx);
6064 });
6065
6066 let context_server_registry =
6067 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6068 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6069 let templates = crate::Templates::new();
6070 let thread = cx.new(|cx| {
6071 crate::Thread::new(
6072 project.clone(),
6073 cx.new(|_cx| prompt_store::ProjectContext::default()),
6074 context_server_registry,
6075 templates.clone(),
6076 None,
6077 cx,
6078 )
6079 });
6080
6081 #[allow(clippy::arc_with_non_send_sync)]
6082 let tool = Arc::new(crate::EditFileTool::new(
6083 project,
6084 thread.downgrade(),
6085 language_registry,
6086 templates,
6087 ));
6088 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6089
6090 let _task = cx.update(|cx| {
6091 tool.run(
6092 ToolInput::resolved(crate::EditFileToolInput {
6093 display_description: "Edit README".to_string(),
6094 path: "root/README.md".into(),
6095 mode: crate::EditFileMode::Edit,
6096 }),
6097 event_stream,
6098 cx,
6099 )
6100 });
6101
6102 cx.run_until_parked();
6103
6104 let event = rx.try_next();
6105 assert!(
6106 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6107 "expected no authorization request for allowed .md file"
6108 );
6109}
6110
6111#[gpui::test]
6112async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
6113 init_test(cx);
6114
6115 let fs = FakeFs::new(cx.executor());
6116 fs.insert_tree(
6117 "/root",
6118 json!({
6119 ".zed": {
6120 "settings.json": "{}"
6121 },
6122 "README.md": "# Hello"
6123 }),
6124 )
6125 .await;
6126 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6127
6128 cx.update(|cx| {
6129 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6130 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
6131 agent_settings::AgentSettings::override_global(settings, cx);
6132 });
6133
6134 let context_server_registry =
6135 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6136 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6137 let templates = crate::Templates::new();
6138 let thread = cx.new(|cx| {
6139 crate::Thread::new(
6140 project.clone(),
6141 cx.new(|_cx| prompt_store::ProjectContext::default()),
6142 context_server_registry,
6143 templates.clone(),
6144 None,
6145 cx,
6146 )
6147 });
6148
6149 #[allow(clippy::arc_with_non_send_sync)]
6150 let tool = Arc::new(crate::EditFileTool::new(
6151 project,
6152 thread.downgrade(),
6153 language_registry,
6154 templates,
6155 ));
6156
6157 // Editing a file inside .zed/ should still prompt even with global default: allow,
6158 // because local settings paths are sensitive and require confirmation regardless.
6159 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6160 let _task = cx.update(|cx| {
6161 tool.run(
6162 ToolInput::resolved(crate::EditFileToolInput {
6163 display_description: "Edit local settings".to_string(),
6164 path: "root/.zed/settings.json".into(),
6165 mode: crate::EditFileMode::Edit,
6166 }),
6167 event_stream,
6168 cx,
6169 )
6170 });
6171
6172 let _update = rx.expect_update_fields().await;
6173 let _auth = rx.expect_authorization().await;
6174}
6175
6176#[gpui::test]
6177async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
6178 init_test(cx);
6179
6180 cx.update(|cx| {
6181 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6182 settings.tool_permissions.tools.insert(
6183 FetchTool::NAME.into(),
6184 agent_settings::ToolRules {
6185 default: Some(settings::ToolPermissionMode::Allow),
6186 always_allow: vec![],
6187 always_deny: vec![
6188 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
6189 ],
6190 always_confirm: vec![],
6191 invalid_patterns: vec![],
6192 },
6193 );
6194 agent_settings::AgentSettings::override_global(settings, cx);
6195 });
6196
6197 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6198
6199 #[allow(clippy::arc_with_non_send_sync)]
6200 let tool = Arc::new(crate::FetchTool::new(http_client));
6201 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6202
6203 let input: crate::FetchToolInput =
6204 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
6205
6206 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6207
6208 let result = task.await;
6209 assert!(result.is_err(), "expected fetch to be blocked");
6210 assert!(
6211 result.unwrap_err().contains("blocked"),
6212 "error should mention the fetch was blocked"
6213 );
6214}
6215
6216#[gpui::test]
6217async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6218 init_test(cx);
6219
6220 cx.update(|cx| {
6221 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6222 settings.tool_permissions.tools.insert(
6223 FetchTool::NAME.into(),
6224 agent_settings::ToolRules {
6225 default: Some(settings::ToolPermissionMode::Confirm),
6226 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
6227 always_deny: vec![],
6228 always_confirm: vec![],
6229 invalid_patterns: vec![],
6230 },
6231 );
6232 agent_settings::AgentSettings::override_global(settings, cx);
6233 });
6234
6235 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6236
6237 #[allow(clippy::arc_with_non_send_sync)]
6238 let tool = Arc::new(crate::FetchTool::new(http_client));
6239 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6240
6241 let input: crate::FetchToolInput =
6242 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
6243
6244 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6245
6246 cx.run_until_parked();
6247
6248 let event = rx.try_next();
6249 assert!(
6250 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6251 "expected no authorization request for allowed docs.rs URL"
6252 );
6253}
6254
/// Verifies that when a message is queued while a turn is mid-tool-call, the
/// turn ends at the next boundary (after the tool completes) with EndTurn
/// instead of sending another completion request, and the queued flag survives.
#[gpui::test]
async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Add a tool so we can simulate tool calls
    thread.update(cx, |thread, _cx| {
        thread.add_tool(EchoTool);
    });

    // Start a turn by sending a message
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate the model making a tool call
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: r#"{"text": "hello"}"#.into(),
            input: json!({"text": "hello"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));

    // Signal that a message is queued before ending the stream
    thread.update(cx, |thread, _cx| {
        thread.set_has_queued_message(true);
    });

    // Now end the stream - tool will run, and the boundary check should see the queue
    fake_model.end_last_completion_stream();

    // Collect all events until the turn stops
    let all_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we received the tool call event
    let tool_call_ids: Vec<_> = all_events
        .iter()
        .filter_map(|e| match e {
            Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
            _ => None,
        })
        .collect();
    assert_eq!(
        tool_call_ids,
        vec!["tool_1"],
        "Should have received a tool call event for our echo tool"
    );

    // The turn should have stopped with EndTurn
    // (rather than issuing another completion request for the tool result).
    let stop_reasons = stop_events(all_events);
    assert_eq!(
        stop_reasons,
        vec![acp::StopReason::EndTurn],
        "Turn should have ended after tool completion due to queued message"
    );

    // Verify the queued message flag is still set
    // (it is the caller's job to consume it, not the turn's).
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.has_queued_message(),
            "Should still have queued message flag set"
        );
    });

    // Thread should be idle now
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Thread should not be running after turn ends"
        );
    });
}
6339
6340#[gpui::test]
6341async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) {
6342 init_test(cx);
6343 always_allow_tools(cx);
6344
6345 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
6346 let fake_model = model.as_fake();
6347
6348 thread.update(cx, |thread, _cx| {
6349 thread.add_tool(StreamingFailingEchoTool {
6350 receive_chunks_until_failure: 1,
6351 });
6352 });
6353
6354 let _events = thread
6355 .update(cx, |thread, cx| {
6356 thread.send(
6357 UserMessageId::new(),
6358 ["Use the streaming_failing_echo tool"],
6359 cx,
6360 )
6361 })
6362 .unwrap();
6363 cx.run_until_parked();
6364
6365 let tool_use = LanguageModelToolUse {
6366 id: "call_1".into(),
6367 name: StreamingFailingEchoTool::NAME.into(),
6368 raw_input: "hello".into(),
6369 input: json!({}),
6370 is_input_complete: false,
6371 thought_signature: None,
6372 };
6373
6374 fake_model
6375 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
6376
6377 cx.run_until_parked();
6378
6379 let completions = fake_model.pending_completions();
6380 let last_completion = completions.last().unwrap();
6381
6382 assert_eq!(
6383 last_completion.messages[1..],
6384 vec![
6385 LanguageModelRequestMessage {
6386 role: Role::User,
6387 content: vec!["Use the streaming_failing_echo tool".into()],
6388 cache: false,
6389 reasoning_details: None,
6390 },
6391 LanguageModelRequestMessage {
6392 role: Role::Assistant,
6393 content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
6394 cache: false,
6395 reasoning_details: None,
6396 },
6397 LanguageModelRequestMessage {
6398 role: Role::User,
6399 content: vec![language_model::MessageContent::ToolResult(
6400 LanguageModelToolResult {
6401 tool_use_id: tool_use.id.clone(),
6402 tool_name: tool_use.name,
6403 is_error: true,
6404 content: "failed".into(),
6405 output: Some("failed".into()),
6406 }
6407 )],
6408 cache: true,
6409 reasoning_details: None,
6410 },
6411 ]
6412 );
6413}
6414
#[gpui::test]
async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Holds the first (non-failing) tool open until we explicitly signal it,
    // so the second tool's failure happens while the first is still running.
    let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) =
        oneshot::channel();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(
            StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx),
        );
        // This tool fails after receiving its first input chunk.
        thread.add_tool(StreamingFailingEchoTool {
            receive_chunks_until_failure: 1,
        });
    });

    let _events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Use the streaming_echo tool and the streaming_failing_echo tool"],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();

    // Partial input for the first tool call (is_input_complete: false)...
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "call_1".into(),
            name: StreamingEchoTool::NAME.into(),
            raw_input: "hello".into(),
            input: json!({ "text": "hello" }),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    // ...followed by the completed input for the same call id ("call_1").
    let first_tool_use = LanguageModelToolUse {
        id: "call_1".into(),
        name: StreamingEchoTool::NAME.into(),
        raw_input: "hello world".into(),
        input: json!({ "text": "hello world" }),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        first_tool_use.clone(),
    ));
    // A second, still-streaming tool call to the failing tool; its first chunk
    // triggers the failure while the echo tool is blocked on the oneshot.
    let second_tool_use = LanguageModelToolUse {
        name: StreamingFailingEchoTool::NAME.into(),
        raw_input: "hello".into(),
        input: json!({ "text": "hello" }),
        is_input_complete: false,
        thought_signature: None,
        id: "call_2".into(),
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        second_tool_use.clone(),
    ));

    // Let the failure land while the first tool is still in flight.
    cx.run_until_parked();

    // Now release the first tool so it can finish successfully.
    complete_streaming_echo_tool_call_tx.send(()).unwrap();

    cx.run_until_parked();

    let completions = fake_model.pending_completions();
    let last_completion = completions.last().unwrap();

    // The follow-up request must contain BOTH tool results: the failing tool's
    // error result and the first tool's successful result (note the error
    // result appears first here, even though that call was issued second) —
    // i.e. the error did not cut off the still-running prior tool.
    assert_eq!(
        last_completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    "Use the streaming_echo tool and the streaming_failing_echo tool".into()
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    language_model::MessageContent::ToolUse(first_tool_use.clone()),
                    language_model::MessageContent::ToolUse(second_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    language_model::MessageContent::ToolResult(LanguageModelToolResult {
                        tool_use_id: second_tool_use.id.clone(),
                        tool_name: second_tool_use.name,
                        is_error: true,
                        content: "failed".into(),
                        output: Some("failed".into()),
                    }),
                    language_model::MessageContent::ToolResult(LanguageModelToolResult {
                        tool_use_id: first_tool_use.id.clone(),
                        tool_name: first_tool_use.name,
                        is_error: false,
                        content: "hello world".into(),
                        output: Some("hello world".into()),
                    }),
                ],
                cache: true,
                reasoning_details: None,
            },
        ]
    );
}