1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 FutureExt as _, StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17 future::{Fuse, Shared},
18};
19use gpui::{
20 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
21 http_client::FakeHttpClient,
22};
23use indoc::indoc;
24
25use language_model::{
26 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
27 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
29 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
30};
31use pretty_assertions::assert_eq;
32use project::{
33 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
34};
35use prompt_store::ProjectContext;
36use reqwest_client::ReqwestClient;
37use schemars::JsonSchema;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use settings::{Settings, SettingsStore};
41use std::{
42 path::Path,
43 pin::Pin,
44 rc::Rc,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, Ordering},
48 },
49 time::Duration,
50};
51use util::path;
52
53mod test_tools;
54use test_tools::*;
55
56fn init_test(cx: &mut TestAppContext) {
57 cx.update(|cx| {
58 let settings_store = SettingsStore::test(cx);
59 cx.set_global(settings_store);
60 });
61}
62
/// Test double for a terminal handle; it implements `crate::TerminalHandle`
/// and records what was done to it so tests can assert on it afterwards.
struct FakeTerminalHandle {
    /// Set to `true` when `kill` is called on the handle.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; tests toggle it via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Sending half of the exit signal; taken (left as `None`) the first time it fires.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task handed out by `wait_for_exit`, resolving to the exit status.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `id`.
    id: acp::TerminalId,
}
71
72impl FakeTerminalHandle {
73 fn new_never_exits(cx: &mut App) -> Self {
74 let killed = Arc::new(AtomicBool::new(false));
75 let stopped_by_user = Arc::new(AtomicBool::new(false));
76
77 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
78
79 let wait_for_exit = cx
80 .spawn(async move |_cx| {
81 // Wait for the exit signal (sent when kill() is called)
82 let _ = exit_receiver.await;
83 acp::TerminalExitStatus::new()
84 })
85 .shared();
86
87 Self {
88 killed,
89 stopped_by_user,
90 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
91 wait_for_exit,
92 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
93 id: acp::TerminalId::new("fake_terminal".to_string()),
94 }
95 }
96
97 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
98 let killed = Arc::new(AtomicBool::new(false));
99 let stopped_by_user = Arc::new(AtomicBool::new(false));
100 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
101
102 let wait_for_exit = cx
103 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
104 .shared();
105
106 Self {
107 killed,
108 stopped_by_user,
109 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
110 wait_for_exit,
111 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
112 id: acp::TerminalId::new("fake_terminal".to_string()),
113 }
114 }
115
116 fn was_killed(&self) -> bool {
117 self.killed.load(Ordering::SeqCst)
118 }
119
120 fn set_stopped_by_user(&self, stopped: bool) {
121 self.stopped_by_user.store(stopped, Ordering::SeqCst);
122 }
123
124 fn signal_exit(&self) {
125 if let Some(sender) = self.exit_sender.borrow_mut().take() {
126 let _ = sender.send(());
127 }
128 }
129}
130
131impl crate::TerminalHandle for FakeTerminalHandle {
132 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
133 Ok(self.id.clone())
134 }
135
136 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
137 Ok(self.output.clone())
138 }
139
140 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
141 Ok(self.wait_for_exit.clone())
142 }
143
144 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
145 self.killed.store(true, Ordering::SeqCst);
146 self.signal_exit();
147 Ok(())
148 }
149
150 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
151 Ok(self.stopped_by_user.load(Ordering::SeqCst))
152 }
153}
154
/// Thread environment whose `create_terminal` always hands back the same
/// pre-built fake terminal handle.
struct FakeThreadEnvironment {
    /// The single handle returned for every terminal request.
    handle: Rc<FakeTerminalHandle>,
}
158
159impl crate::ThreadEnvironment for FakeThreadEnvironment {
160 fn create_terminal(
161 &self,
162 _command: String,
163 _cwd: Option<std::path::PathBuf>,
164 _output_byte_limit: Option<u64>,
165 _cx: &mut AsyncApp,
166 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
167 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
168 }
169}
170
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created so far, in creation order (appended by `create_terminal`).
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
175
176impl MultiTerminalEnvironment {
177 fn new() -> Self {
178 Self {
179 handles: std::cell::RefCell::new(Vec::new()),
180 }
181 }
182
183 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
184 self.handles.borrow().clone()
185 }
186}
187
188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
189 fn create_terminal(
190 &self,
191 _command: String,
192 _cwd: Option<std::path::PathBuf>,
193 _output_byte_limit: Option<u64>,
194 cx: &mut AsyncApp,
195 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
196 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
197 self.handles.borrow_mut().push(handle.clone());
198 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
199 }
200}
201
202fn always_allow_tools(cx: &mut TestAppContext) {
203 cx.update(|cx| {
204 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
205 settings.always_allow_tool_actions = true;
206 agent_settings::AgentSettings::override_global(settings, cx);
207 });
208}
209
210#[gpui::test]
211async fn test_echo(cx: &mut TestAppContext) {
212 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
213 let fake_model = model.as_fake();
214
215 let events = thread
216 .update(cx, |thread, cx| {
217 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
218 })
219 .unwrap();
220 cx.run_until_parked();
221 fake_model.send_last_completion_stream_text_chunk("Hello");
222 fake_model
223 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
224 fake_model.end_last_completion_stream();
225
226 let events = events.collect().await;
227 thread.update(cx, |thread, _cx| {
228 assert_eq!(
229 thread.last_message().unwrap().to_markdown(),
230 indoc! {"
231 ## Assistant
232
233 Hello
234 "}
235 )
236 });
237 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
238}
239
240#[gpui::test]
241async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
242 init_test(cx);
243 always_allow_tools(cx);
244
245 let fs = FakeFs::new(cx.executor());
246 let project = Project::test(fs, [], cx).await;
247
248 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
249 let environment = Rc::new(FakeThreadEnvironment {
250 handle: handle.clone(),
251 });
252
253 #[allow(clippy::arc_with_non_send_sync)]
254 let tool = Arc::new(crate::TerminalTool::new(project, environment));
255 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
256
257 let task = cx.update(|cx| {
258 tool.run(
259 crate::TerminalToolInput {
260 command: "sleep 1000".to_string(),
261 cd: ".".to_string(),
262 timeout_ms: Some(5),
263 },
264 event_stream,
265 cx,
266 )
267 });
268
269 let update = rx.expect_update_fields().await;
270 assert!(
271 update.content.iter().any(|blocks| {
272 blocks
273 .iter()
274 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
275 }),
276 "expected tool call update to include terminal content"
277 );
278
279 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
280
281 let deadline = std::time::Instant::now() + Duration::from_millis(500);
282 loop {
283 if let Some(result) = task_future.as_mut().now_or_never() {
284 let result = result.expect("terminal tool task should complete");
285
286 assert!(
287 handle.was_killed(),
288 "expected terminal handle to be killed on timeout"
289 );
290 assert!(
291 result.contains("partial output"),
292 "expected result to include terminal output, got: {result}"
293 );
294 return;
295 }
296
297 if std::time::Instant::now() >= deadline {
298 panic!("timed out waiting for terminal tool task to complete");
299 }
300
301 cx.run_until_parked();
302 cx.background_executor.timer(Duration::from_millis(1)).await;
303 }
304}
305
306#[gpui::test]
307#[ignore]
308async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
309 init_test(cx);
310 always_allow_tools(cx);
311
312 let fs = FakeFs::new(cx.executor());
313 let project = Project::test(fs, [], cx).await;
314
315 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
316 let environment = Rc::new(FakeThreadEnvironment {
317 handle: handle.clone(),
318 });
319
320 #[allow(clippy::arc_with_non_send_sync)]
321 let tool = Arc::new(crate::TerminalTool::new(project, environment));
322 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
323
324 let _task = cx.update(|cx| {
325 tool.run(
326 crate::TerminalToolInput {
327 command: "sleep 1000".to_string(),
328 cd: ".".to_string(),
329 timeout_ms: None,
330 },
331 event_stream,
332 cx,
333 )
334 });
335
336 let update = rx.expect_update_fields().await;
337 assert!(
338 update.content.iter().any(|blocks| {
339 blocks
340 .iter()
341 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
342 }),
343 "expected tool call update to include terminal content"
344 );
345
346 smol::Timer::after(Duration::from_millis(25)).await;
347
348 assert!(
349 !handle.was_killed(),
350 "did not expect terminal handle to be killed without a timeout"
351 );
352}
353
354#[gpui::test]
355async fn test_thinking(cx: &mut TestAppContext) {
356 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
357 let fake_model = model.as_fake();
358
359 let events = thread
360 .update(cx, |thread, cx| {
361 thread.send(
362 UserMessageId::new(),
363 [indoc! {"
364 Testing:
365
366 Generate a thinking step where you just think the word 'Think',
367 and have your final answer be 'Hello'
368 "}],
369 cx,
370 )
371 })
372 .unwrap();
373 cx.run_until_parked();
374 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
375 text: "Think".to_string(),
376 signature: None,
377 });
378 fake_model.send_last_completion_stream_text_chunk("Hello");
379 fake_model
380 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
381 fake_model.end_last_completion_stream();
382
383 let events = events.collect().await;
384 thread.update(cx, |thread, _cx| {
385 assert_eq!(
386 thread.last_message().unwrap().to_markdown(),
387 indoc! {"
388 ## Assistant
389
390 <think>Think</think>
391 Hello
392 "}
393 )
394 });
395 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
396}
397
398#[gpui::test]
399async fn test_system_prompt(cx: &mut TestAppContext) {
400 let ThreadTest {
401 model,
402 thread,
403 project_context,
404 ..
405 } = setup(cx, TestModel::Fake).await;
406 let fake_model = model.as_fake();
407
408 project_context.update(cx, |project_context, _cx| {
409 project_context.shell = "test-shell".into()
410 });
411 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
412 thread
413 .update(cx, |thread, cx| {
414 thread.send(UserMessageId::new(), ["abc"], cx)
415 })
416 .unwrap();
417 cx.run_until_parked();
418 let mut pending_completions = fake_model.pending_completions();
419 assert_eq!(
420 pending_completions.len(),
421 1,
422 "unexpected pending completions: {:?}",
423 pending_completions
424 );
425
426 let pending_completion = pending_completions.pop().unwrap();
427 assert_eq!(pending_completion.messages[0].role, Role::System);
428
429 let system_message = &pending_completion.messages[0];
430 let system_prompt = system_message.content[0].to_str().unwrap();
431 assert!(
432 system_prompt.contains("test-shell"),
433 "unexpected system message: {:?}",
434 system_message
435 );
436 assert!(
437 system_prompt.contains("## Fixing Diagnostics"),
438 "unexpected system message: {:?}",
439 system_message
440 );
441}
442
443#[gpui::test]
444async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
445 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
446 let fake_model = model.as_fake();
447
448 thread
449 .update(cx, |thread, cx| {
450 thread.send(UserMessageId::new(), ["abc"], cx)
451 })
452 .unwrap();
453 cx.run_until_parked();
454 let mut pending_completions = fake_model.pending_completions();
455 assert_eq!(
456 pending_completions.len(),
457 1,
458 "unexpected pending completions: {:?}",
459 pending_completions
460 );
461
462 let pending_completion = pending_completions.pop().unwrap();
463 assert_eq!(pending_completion.messages[0].role, Role::System);
464
465 let system_message = &pending_completion.messages[0];
466 let system_prompt = system_message.content[0].to_str().unwrap();
467 assert!(
468 !system_prompt.contains("## Tool Use"),
469 "unexpected system message: {:?}",
470 system_message
471 );
472 assert!(
473 !system_prompt.contains("## Fixing Diagnostics"),
474 "unexpected system message: {:?}",
475 system_message
476 );
477}
478
479#[gpui::test]
480async fn test_prompt_caching(cx: &mut TestAppContext) {
481 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
482 let fake_model = model.as_fake();
483
484 // Send initial user message and verify it's cached
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["Message 1"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491
492 let completion = fake_model.pending_completions().pop().unwrap();
493 assert_eq!(
494 completion.messages[1..],
495 vec![LanguageModelRequestMessage {
496 role: Role::User,
497 content: vec!["Message 1".into()],
498 cache: true,
499 reasoning_details: None,
500 }]
501 );
502 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
503 "Response to Message 1".into(),
504 ));
505 fake_model.end_last_completion_stream();
506 cx.run_until_parked();
507
508 // Send another user message and verify only the latest is cached
509 thread
510 .update(cx, |thread, cx| {
511 thread.send(UserMessageId::new(), ["Message 2"], cx)
512 })
513 .unwrap();
514 cx.run_until_parked();
515
516 let completion = fake_model.pending_completions().pop().unwrap();
517 assert_eq!(
518 completion.messages[1..],
519 vec![
520 LanguageModelRequestMessage {
521 role: Role::User,
522 content: vec!["Message 1".into()],
523 cache: false,
524 reasoning_details: None,
525 },
526 LanguageModelRequestMessage {
527 role: Role::Assistant,
528 content: vec!["Response to Message 1".into()],
529 cache: false,
530 reasoning_details: None,
531 },
532 LanguageModelRequestMessage {
533 role: Role::User,
534 content: vec!["Message 2".into()],
535 cache: true,
536 reasoning_details: None,
537 }
538 ]
539 );
540 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
541 "Response to Message 2".into(),
542 ));
543 fake_model.end_last_completion_stream();
544 cx.run_until_parked();
545
546 // Simulate a tool call and verify that the latest tool result is cached
547 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
548 thread
549 .update(cx, |thread, cx| {
550 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
551 })
552 .unwrap();
553 cx.run_until_parked();
554
555 let tool_use = LanguageModelToolUse {
556 id: "tool_1".into(),
557 name: EchoTool::name().into(),
558 raw_input: json!({"text": "test"}).to_string(),
559 input: json!({"text": "test"}),
560 is_input_complete: true,
561 thought_signature: None,
562 };
563 fake_model
564 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
565 fake_model.end_last_completion_stream();
566 cx.run_until_parked();
567
568 let completion = fake_model.pending_completions().pop().unwrap();
569 let tool_result = LanguageModelToolResult {
570 tool_use_id: "tool_1".into(),
571 tool_name: EchoTool::name().into(),
572 is_error: false,
573 content: "test".into(),
574 output: Some("test".into()),
575 };
576 assert_eq!(
577 completion.messages[1..],
578 vec![
579 LanguageModelRequestMessage {
580 role: Role::User,
581 content: vec!["Message 1".into()],
582 cache: false,
583 reasoning_details: None,
584 },
585 LanguageModelRequestMessage {
586 role: Role::Assistant,
587 content: vec!["Response to Message 1".into()],
588 cache: false,
589 reasoning_details: None,
590 },
591 LanguageModelRequestMessage {
592 role: Role::User,
593 content: vec!["Message 2".into()],
594 cache: false,
595 reasoning_details: None,
596 },
597 LanguageModelRequestMessage {
598 role: Role::Assistant,
599 content: vec!["Response to Message 2".into()],
600 cache: false,
601 reasoning_details: None,
602 },
603 LanguageModelRequestMessage {
604 role: Role::User,
605 content: vec!["Use the echo tool".into()],
606 cache: false,
607 reasoning_details: None,
608 },
609 LanguageModelRequestMessage {
610 role: Role::Assistant,
611 content: vec![MessageContent::ToolUse(tool_use)],
612 cache: false,
613 reasoning_details: None,
614 },
615 LanguageModelRequestMessage {
616 role: Role::User,
617 content: vec![MessageContent::ToolResult(tool_result)],
618 cache: true,
619 reasoning_details: None,
620 }
621 ]
622 );
623}
624
625#[gpui::test]
626#[cfg_attr(not(feature = "e2e"), ignore)]
627async fn test_basic_tool_calls(cx: &mut TestAppContext) {
628 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
629
630 // Test a tool call that's likely to complete *before* streaming stops.
631 let events = thread
632 .update(cx, |thread, cx| {
633 thread.add_tool(EchoTool);
634 thread.send(
635 UserMessageId::new(),
636 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
637 cx,
638 )
639 })
640 .unwrap()
641 .collect()
642 .await;
643 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
644
645 // Test a tool calls that's likely to complete *after* streaming stops.
646 let events = thread
647 .update(cx, |thread, cx| {
648 thread.remove_tool(&EchoTool::name());
649 thread.add_tool(DelayTool);
650 thread.send(
651 UserMessageId::new(),
652 [
653 "Now call the delay tool with 200ms.",
654 "When the timer goes off, then you echo the output of the tool.",
655 ],
656 cx,
657 )
658 })
659 .unwrap()
660 .collect()
661 .await;
662 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
663 thread.update(cx, |thread, _cx| {
664 assert!(
665 thread
666 .last_message()
667 .unwrap()
668 .as_agent_message()
669 .unwrap()
670 .content
671 .iter()
672 .any(|content| {
673 if let AgentMessageContent::Text(text) = content {
674 text.contains("Ding")
675 } else {
676 false
677 }
678 }),
679 "{}",
680 thread.to_markdown()
681 );
682 });
683}
684
685#[gpui::test]
686#[cfg_attr(not(feature = "e2e"), ignore)]
687async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
688 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
689
690 // Test a tool call that's likely to complete *before* streaming stops.
691 let mut events = thread
692 .update(cx, |thread, cx| {
693 thread.add_tool(WordListTool);
694 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
695 })
696 .unwrap();
697
698 let mut saw_partial_tool_use = false;
699 while let Some(event) = events.next().await {
700 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
701 thread.update(cx, |thread, _cx| {
702 // Look for a tool use in the thread's last message
703 let message = thread.last_message().unwrap();
704 let agent_message = message.as_agent_message().unwrap();
705 let last_content = agent_message.content.last().unwrap();
706 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
707 assert_eq!(last_tool_use.name.as_ref(), "word_list");
708 if tool_call.status == acp::ToolCallStatus::Pending {
709 if !last_tool_use.is_input_complete
710 && last_tool_use.input.get("g").is_none()
711 {
712 saw_partial_tool_use = true;
713 }
714 } else {
715 last_tool_use
716 .input
717 .get("a")
718 .expect("'a' has streamed because input is now complete");
719 last_tool_use
720 .input
721 .get("g")
722 .expect("'g' has streamed because input is now complete");
723 }
724 } else {
725 panic!("last content should be a tool use");
726 }
727 });
728 }
729 }
730
731 assert!(
732 saw_partial_tool_use,
733 "should see at least one partially streamed tool use in the history"
734 );
735}
736
737#[gpui::test]
738async fn test_tool_authorization(cx: &mut TestAppContext) {
739 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
740 let fake_model = model.as_fake();
741
742 let mut events = thread
743 .update(cx, |thread, cx| {
744 thread.add_tool(ToolRequiringPermission);
745 thread.send(UserMessageId::new(), ["abc"], cx)
746 })
747 .unwrap();
748 cx.run_until_parked();
749 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
750 LanguageModelToolUse {
751 id: "tool_id_1".into(),
752 name: ToolRequiringPermission::name().into(),
753 raw_input: "{}".into(),
754 input: json!({}),
755 is_input_complete: true,
756 thought_signature: None,
757 },
758 ));
759 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
760 LanguageModelToolUse {
761 id: "tool_id_2".into(),
762 name: ToolRequiringPermission::name().into(),
763 raw_input: "{}".into(),
764 input: json!({}),
765 is_input_complete: true,
766 thought_signature: None,
767 },
768 ));
769 fake_model.end_last_completion_stream();
770 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
771 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
772
773 // Approve the first
774 tool_call_auth_1
775 .response
776 .send(tool_call_auth_1.options[1].option_id.clone())
777 .unwrap();
778 cx.run_until_parked();
779
780 // Reject the second
781 tool_call_auth_2
782 .response
783 .send(tool_call_auth_1.options[2].option_id.clone())
784 .unwrap();
785 cx.run_until_parked();
786
787 let completion = fake_model.pending_completions().pop().unwrap();
788 let message = completion.messages.last().unwrap();
789 assert_eq!(
790 message.content,
791 vec![
792 language_model::MessageContent::ToolResult(LanguageModelToolResult {
793 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
794 tool_name: ToolRequiringPermission::name().into(),
795 is_error: false,
796 content: "Allowed".into(),
797 output: Some("Allowed".into())
798 }),
799 language_model::MessageContent::ToolResult(LanguageModelToolResult {
800 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
801 tool_name: ToolRequiringPermission::name().into(),
802 is_error: true,
803 content: "Permission to run tool denied by user".into(),
804 output: Some("Permission to run tool denied by user".into())
805 })
806 ]
807 );
808
809 // Simulate yet another tool call.
810 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
811 LanguageModelToolUse {
812 id: "tool_id_3".into(),
813 name: ToolRequiringPermission::name().into(),
814 raw_input: "{}".into(),
815 input: json!({}),
816 is_input_complete: true,
817 thought_signature: None,
818 },
819 ));
820 fake_model.end_last_completion_stream();
821
822 // Respond by always allowing tools.
823 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
824 tool_call_auth_3
825 .response
826 .send(tool_call_auth_3.options[0].option_id.clone())
827 .unwrap();
828 cx.run_until_parked();
829 let completion = fake_model.pending_completions().pop().unwrap();
830 let message = completion.messages.last().unwrap();
831 assert_eq!(
832 message.content,
833 vec![language_model::MessageContent::ToolResult(
834 LanguageModelToolResult {
835 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
836 tool_name: ToolRequiringPermission::name().into(),
837 is_error: false,
838 content: "Allowed".into(),
839 output: Some("Allowed".into())
840 }
841 )]
842 );
843
844 // Simulate a final tool call, ensuring we don't trigger authorization.
845 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
846 LanguageModelToolUse {
847 id: "tool_id_4".into(),
848 name: ToolRequiringPermission::name().into(),
849 raw_input: "{}".into(),
850 input: json!({}),
851 is_input_complete: true,
852 thought_signature: None,
853 },
854 ));
855 fake_model.end_last_completion_stream();
856 cx.run_until_parked();
857 let completion = fake_model.pending_completions().pop().unwrap();
858 let message = completion.messages.last().unwrap();
859 assert_eq!(
860 message.content,
861 vec![language_model::MessageContent::ToolResult(
862 LanguageModelToolResult {
863 tool_use_id: "tool_id_4".into(),
864 tool_name: ToolRequiringPermission::name().into(),
865 is_error: false,
866 content: "Allowed".into(),
867 output: Some("Allowed".into())
868 }
869 )]
870 );
871}
872
873#[gpui::test]
874async fn test_tool_hallucination(cx: &mut TestAppContext) {
875 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
876 let fake_model = model.as_fake();
877
878 let mut events = thread
879 .update(cx, |thread, cx| {
880 thread.send(UserMessageId::new(), ["abc"], cx)
881 })
882 .unwrap();
883 cx.run_until_parked();
884 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
885 LanguageModelToolUse {
886 id: "tool_id_1".into(),
887 name: "nonexistent_tool".into(),
888 raw_input: "{}".into(),
889 input: json!({}),
890 is_input_complete: true,
891 thought_signature: None,
892 },
893 ));
894 fake_model.end_last_completion_stream();
895
896 let tool_call = expect_tool_call(&mut events).await;
897 assert_eq!(tool_call.title, "nonexistent_tool");
898 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
899 let update = expect_tool_call_update_fields(&mut events).await;
900 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
901}
902
903#[gpui::test]
904async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
905 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
906 let fake_model = model.as_fake();
907
908 let events = thread
909 .update(cx, |thread, cx| {
910 thread.add_tool(EchoTool);
911 thread.send(UserMessageId::new(), ["abc"], cx)
912 })
913 .unwrap();
914 cx.run_until_parked();
915 let tool_use = LanguageModelToolUse {
916 id: "tool_id_1".into(),
917 name: EchoTool::name().into(),
918 raw_input: "{}".into(),
919 input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
920 is_input_complete: true,
921 thought_signature: None,
922 };
923 fake_model
924 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
925 fake_model.end_last_completion_stream();
926
927 cx.run_until_parked();
928 let completion = fake_model.pending_completions().pop().unwrap();
929 let tool_result = LanguageModelToolResult {
930 tool_use_id: "tool_id_1".into(),
931 tool_name: EchoTool::name().into(),
932 is_error: false,
933 content: "def".into(),
934 output: Some("def".into()),
935 };
936 assert_eq!(
937 completion.messages[1..],
938 vec![
939 LanguageModelRequestMessage {
940 role: Role::User,
941 content: vec!["abc".into()],
942 cache: false,
943 reasoning_details: None,
944 },
945 LanguageModelRequestMessage {
946 role: Role::Assistant,
947 content: vec![MessageContent::ToolUse(tool_use.clone())],
948 cache: false,
949 reasoning_details: None,
950 },
951 LanguageModelRequestMessage {
952 role: Role::User,
953 content: vec![MessageContent::ToolResult(tool_result.clone())],
954 cache: true,
955 reasoning_details: None,
956 },
957 ]
958 );
959
960 // Simulate reaching tool use limit.
961 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
962 fake_model.end_last_completion_stream();
963 let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
964 assert!(
965 last_event
966 .unwrap_err()
967 .is::<language_model::ToolUseLimitReachedError>()
968 );
969
970 let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
971 cx.run_until_parked();
972 let completion = fake_model.pending_completions().pop().unwrap();
973 assert_eq!(
974 completion.messages[1..],
975 vec![
976 LanguageModelRequestMessage {
977 role: Role::User,
978 content: vec!["abc".into()],
979 cache: false,
980 reasoning_details: None,
981 },
982 LanguageModelRequestMessage {
983 role: Role::Assistant,
984 content: vec![MessageContent::ToolUse(tool_use)],
985 cache: false,
986 reasoning_details: None,
987 },
988 LanguageModelRequestMessage {
989 role: Role::User,
990 content: vec![MessageContent::ToolResult(tool_result)],
991 cache: false,
992 reasoning_details: None,
993 },
994 LanguageModelRequestMessage {
995 role: Role::User,
996 content: vec!["Continue where you left off".into()],
997 cache: true,
998 reasoning_details: None,
999 }
1000 ]
1001 );
1002
1003 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
1004 fake_model.end_last_completion_stream();
1005 events.collect::<Vec<_>>().await;
1006 thread.read_with(cx, |thread, _cx| {
1007 assert_eq!(
1008 thread.last_message().unwrap().to_markdown(),
1009 indoc! {"
1010 ## Assistant
1011
1012 Done
1013 "}
1014 )
1015 });
1016}
1017
1018#[gpui::test]
1019async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
1020 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1021 let fake_model = model.as_fake();
1022
1023 let events = thread
1024 .update(cx, |thread, cx| {
1025 thread.add_tool(EchoTool);
1026 thread.send(UserMessageId::new(), ["abc"], cx)
1027 })
1028 .unwrap();
1029 cx.run_until_parked();
1030
1031 let tool_use = LanguageModelToolUse {
1032 id: "tool_id_1".into(),
1033 name: EchoTool::name().into(),
1034 raw_input: "{}".into(),
1035 input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
1036 is_input_complete: true,
1037 thought_signature: None,
1038 };
1039 let tool_result = LanguageModelToolResult {
1040 tool_use_id: "tool_id_1".into(),
1041 tool_name: EchoTool::name().into(),
1042 is_error: false,
1043 content: "def".into(),
1044 output: Some("def".into()),
1045 };
1046 fake_model
1047 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
1048 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
1049 fake_model.end_last_completion_stream();
1050 let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
1051 assert!(
1052 last_event
1053 .unwrap_err()
1054 .is::<language_model::ToolUseLimitReachedError>()
1055 );
1056
1057 thread
1058 .update(cx, |thread, cx| {
1059 thread.send(UserMessageId::new(), vec!["ghi"], cx)
1060 })
1061 .unwrap();
1062 cx.run_until_parked();
1063 let completion = fake_model.pending_completions().pop().unwrap();
1064 assert_eq!(
1065 completion.messages[1..],
1066 vec![
1067 LanguageModelRequestMessage {
1068 role: Role::User,
1069 content: vec!["abc".into()],
1070 cache: false,
1071 reasoning_details: None,
1072 },
1073 LanguageModelRequestMessage {
1074 role: Role::Assistant,
1075 content: vec![MessageContent::ToolUse(tool_use)],
1076 cache: false,
1077 reasoning_details: None,
1078 },
1079 LanguageModelRequestMessage {
1080 role: Role::User,
1081 content: vec![MessageContent::ToolResult(tool_result)],
1082 cache: false,
1083 reasoning_details: None,
1084 },
1085 LanguageModelRequestMessage {
1086 role: Role::User,
1087 content: vec!["ghi".into()],
1088 cache: true,
1089 reasoning_details: None,
1090 }
1091 ]
1092 );
1093}
1094
1095async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1096 let event = events
1097 .next()
1098 .await
1099 .expect("no tool call authorization event received")
1100 .unwrap();
1101 match event {
1102 ThreadEvent::ToolCall(tool_call) => tool_call,
1103 event => {
1104 panic!("Unexpected event {event:?}");
1105 }
1106 }
1107}
1108
1109async fn expect_tool_call_update_fields(
1110 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1111) -> acp::ToolCallUpdate {
1112 let event = events
1113 .next()
1114 .await
1115 .expect("no tool call authorization event received")
1116 .unwrap();
1117 match event {
1118 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1119 event => {
1120 panic!("Unexpected event {event:?}");
1121 }
1122 }
1123}
1124
1125async fn next_tool_call_authorization(
1126 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1127) -> ToolCallAuthorization {
1128 loop {
1129 let event = events
1130 .next()
1131 .await
1132 .expect("no tool call authorization event received")
1133 .unwrap();
1134 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1135 let permission_kinds = tool_call_authorization
1136 .options
1137 .iter()
1138 .map(|o| o.kind)
1139 .collect::<Vec<_>>();
1140 assert_eq!(
1141 permission_kinds,
1142 vec![
1143 acp::PermissionOptionKind::AllowAlways,
1144 acp::PermissionOptionKind::AllowOnce,
1145 acp::PermissionOptionKind::RejectOnce,
1146 ]
1147 );
1148 return tool_call_authorization;
1149 }
1150 }
1151}
1152
1153#[gpui::test]
1154#[cfg_attr(not(feature = "e2e"), ignore)]
1155async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1156 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1157
1158 // Test concurrent tool calls with different delay times
1159 let events = thread
1160 .update(cx, |thread, cx| {
1161 thread.add_tool(DelayTool);
1162 thread.send(
1163 UserMessageId::new(),
1164 [
1165 "Call the delay tool twice in the same message.",
1166 "Once with 100ms. Once with 300ms.",
1167 "When both timers are complete, describe the outputs.",
1168 ],
1169 cx,
1170 )
1171 })
1172 .unwrap()
1173 .collect()
1174 .await;
1175
1176 let stop_reasons = stop_events(events);
1177 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1178
1179 thread.update(cx, |thread, _cx| {
1180 let last_message = thread.last_message().unwrap();
1181 let agent_message = last_message.as_agent_message().unwrap();
1182 let text = agent_message
1183 .content
1184 .iter()
1185 .filter_map(|content| {
1186 if let AgentMessageContent::Text(text) = content {
1187 Some(text.as_str())
1188 } else {
1189 None
1190 }
1191 })
1192 .collect::<String>();
1193
1194 assert!(text.contains("Ding"));
1195 });
1196}
1197
1198#[gpui::test]
1199async fn test_profiles(cx: &mut TestAppContext) {
1200 let ThreadTest {
1201 model, thread, fs, ..
1202 } = setup(cx, TestModel::Fake).await;
1203 let fake_model = model.as_fake();
1204
1205 thread.update(cx, |thread, _cx| {
1206 thread.add_tool(DelayTool);
1207 thread.add_tool(EchoTool);
1208 thread.add_tool(InfiniteTool);
1209 });
1210
1211 // Override profiles and wait for settings to be loaded.
1212 fs.insert_file(
1213 paths::settings_file(),
1214 json!({
1215 "agent": {
1216 "profiles": {
1217 "test-1": {
1218 "name": "Test Profile 1",
1219 "tools": {
1220 EchoTool::name(): true,
1221 DelayTool::name(): true,
1222 }
1223 },
1224 "test-2": {
1225 "name": "Test Profile 2",
1226 "tools": {
1227 InfiniteTool::name(): true,
1228 }
1229 }
1230 }
1231 }
1232 })
1233 .to_string()
1234 .into_bytes(),
1235 )
1236 .await;
1237 cx.run_until_parked();
1238
1239 // Test that test-1 profile (default) has echo and delay tools
1240 thread
1241 .update(cx, |thread, cx| {
1242 thread.set_profile(AgentProfileId("test-1".into()), cx);
1243 thread.send(UserMessageId::new(), ["test"], cx)
1244 })
1245 .unwrap();
1246 cx.run_until_parked();
1247
1248 let mut pending_completions = fake_model.pending_completions();
1249 assert_eq!(pending_completions.len(), 1);
1250 let completion = pending_completions.pop().unwrap();
1251 let tool_names: Vec<String> = completion
1252 .tools
1253 .iter()
1254 .map(|tool| tool.name.clone())
1255 .collect();
1256 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1257 fake_model.end_last_completion_stream();
1258
1259 // Switch to test-2 profile, and verify that it has only the infinite tool.
1260 thread
1261 .update(cx, |thread, cx| {
1262 thread.set_profile(AgentProfileId("test-2".into()), cx);
1263 thread.send(UserMessageId::new(), ["test2"], cx)
1264 })
1265 .unwrap();
1266 cx.run_until_parked();
1267 let mut pending_completions = fake_model.pending_completions();
1268 assert_eq!(pending_completions.len(), 1);
1269 let completion = pending_completions.pop().unwrap();
1270 let tool_names: Vec<String> = completion
1271 .tools
1272 .iter()
1273 .map(|tool| tool.name.clone())
1274 .collect();
1275 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1276}
1277
1278#[gpui::test]
1279async fn test_mcp_tools(cx: &mut TestAppContext) {
1280 let ThreadTest {
1281 model,
1282 thread,
1283 context_server_store,
1284 fs,
1285 ..
1286 } = setup(cx, TestModel::Fake).await;
1287 let fake_model = model.as_fake();
1288
1289 // Override profiles and wait for settings to be loaded.
1290 fs.insert_file(
1291 paths::settings_file(),
1292 json!({
1293 "agent": {
1294 "always_allow_tool_actions": true,
1295 "profiles": {
1296 "test": {
1297 "name": "Test Profile",
1298 "enable_all_context_servers": true,
1299 "tools": {
1300 EchoTool::name(): true,
1301 }
1302 },
1303 }
1304 }
1305 })
1306 .to_string()
1307 .into_bytes(),
1308 )
1309 .await;
1310 cx.run_until_parked();
1311 thread.update(cx, |thread, cx| {
1312 thread.set_profile(AgentProfileId("test".into()), cx)
1313 });
1314
1315 let mut mcp_tool_calls = setup_context_server(
1316 "test_server",
1317 vec![context_server::types::Tool {
1318 name: "echo".into(),
1319 description: None,
1320 input_schema: serde_json::to_value(EchoTool::input_schema(
1321 LanguageModelToolSchemaFormat::JsonSchema,
1322 ))
1323 .unwrap(),
1324 output_schema: None,
1325 annotations: None,
1326 }],
1327 &context_server_store,
1328 cx,
1329 );
1330
1331 let events = thread.update(cx, |thread, cx| {
1332 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1333 });
1334 cx.run_until_parked();
1335
1336 // Simulate the model calling the MCP tool.
1337 let completion = fake_model.pending_completions().pop().unwrap();
1338 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1339 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1340 LanguageModelToolUse {
1341 id: "tool_1".into(),
1342 name: "echo".into(),
1343 raw_input: json!({"text": "test"}).to_string(),
1344 input: json!({"text": "test"}),
1345 is_input_complete: true,
1346 thought_signature: None,
1347 },
1348 ));
1349 fake_model.end_last_completion_stream();
1350 cx.run_until_parked();
1351
1352 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1353 assert_eq!(tool_call_params.name, "echo");
1354 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1355 tool_call_response
1356 .send(context_server::types::CallToolResponse {
1357 content: vec![context_server::types::ToolResponseContent::Text {
1358 text: "test".into(),
1359 }],
1360 is_error: None,
1361 meta: None,
1362 structured_content: None,
1363 })
1364 .unwrap();
1365 cx.run_until_parked();
1366
1367 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1368 fake_model.send_last_completion_stream_text_chunk("Done!");
1369 fake_model.end_last_completion_stream();
1370 events.collect::<Vec<_>>().await;
1371
1372 // Send again after adding the echo tool, ensuring the name collision is resolved.
1373 let events = thread.update(cx, |thread, cx| {
1374 thread.add_tool(EchoTool);
1375 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1376 });
1377 cx.run_until_parked();
1378 let completion = fake_model.pending_completions().pop().unwrap();
1379 assert_eq!(
1380 tool_names_for_completion(&completion),
1381 vec!["echo", "test_server_echo"]
1382 );
1383 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1384 LanguageModelToolUse {
1385 id: "tool_2".into(),
1386 name: "test_server_echo".into(),
1387 raw_input: json!({"text": "mcp"}).to_string(),
1388 input: json!({"text": "mcp"}),
1389 is_input_complete: true,
1390 thought_signature: None,
1391 },
1392 ));
1393 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1394 LanguageModelToolUse {
1395 id: "tool_3".into(),
1396 name: "echo".into(),
1397 raw_input: json!({"text": "native"}).to_string(),
1398 input: json!({"text": "native"}),
1399 is_input_complete: true,
1400 thought_signature: None,
1401 },
1402 ));
1403 fake_model.end_last_completion_stream();
1404 cx.run_until_parked();
1405
1406 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1407 assert_eq!(tool_call_params.name, "echo");
1408 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1409 tool_call_response
1410 .send(context_server::types::CallToolResponse {
1411 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1412 is_error: None,
1413 meta: None,
1414 structured_content: None,
1415 })
1416 .unwrap();
1417 cx.run_until_parked();
1418
1419 // Ensure the tool results were inserted with the correct names.
1420 let completion = fake_model.pending_completions().pop().unwrap();
1421 assert_eq!(
1422 completion.messages.last().unwrap().content,
1423 vec![
1424 MessageContent::ToolResult(LanguageModelToolResult {
1425 tool_use_id: "tool_3".into(),
1426 tool_name: "echo".into(),
1427 is_error: false,
1428 content: "native".into(),
1429 output: Some("native".into()),
1430 },),
1431 MessageContent::ToolResult(LanguageModelToolResult {
1432 tool_use_id: "tool_2".into(),
1433 tool_name: "test_server_echo".into(),
1434 is_error: false,
1435 content: "mcp".into(),
1436 output: Some("mcp".into()),
1437 },),
1438 ]
1439 );
1440 fake_model.end_last_completion_stream();
1441 events.collect::<Vec<_>>().await;
1442}
1443
1444#[gpui::test]
1445async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1446 let ThreadTest {
1447 model,
1448 thread,
1449 context_server_store,
1450 fs,
1451 ..
1452 } = setup(cx, TestModel::Fake).await;
1453 let fake_model = model.as_fake();
1454
1455 // Set up a profile with all tools enabled
1456 fs.insert_file(
1457 paths::settings_file(),
1458 json!({
1459 "agent": {
1460 "profiles": {
1461 "test": {
1462 "name": "Test Profile",
1463 "enable_all_context_servers": true,
1464 "tools": {
1465 EchoTool::name(): true,
1466 DelayTool::name(): true,
1467 WordListTool::name(): true,
1468 ToolRequiringPermission::name(): true,
1469 InfiniteTool::name(): true,
1470 }
1471 },
1472 }
1473 }
1474 })
1475 .to_string()
1476 .into_bytes(),
1477 )
1478 .await;
1479 cx.run_until_parked();
1480
1481 thread.update(cx, |thread, cx| {
1482 thread.set_profile(AgentProfileId("test".into()), cx);
1483 thread.add_tool(EchoTool);
1484 thread.add_tool(DelayTool);
1485 thread.add_tool(WordListTool);
1486 thread.add_tool(ToolRequiringPermission);
1487 thread.add_tool(InfiniteTool);
1488 });
1489
1490 // Set up multiple context servers with some overlapping tool names
1491 let _server1_calls = setup_context_server(
1492 "xxx",
1493 vec![
1494 context_server::types::Tool {
1495 name: "echo".into(), // Conflicts with native EchoTool
1496 description: None,
1497 input_schema: serde_json::to_value(EchoTool::input_schema(
1498 LanguageModelToolSchemaFormat::JsonSchema,
1499 ))
1500 .unwrap(),
1501 output_schema: None,
1502 annotations: None,
1503 },
1504 context_server::types::Tool {
1505 name: "unique_tool_1".into(),
1506 description: None,
1507 input_schema: json!({"type": "object", "properties": {}}),
1508 output_schema: None,
1509 annotations: None,
1510 },
1511 ],
1512 &context_server_store,
1513 cx,
1514 );
1515
1516 let _server2_calls = setup_context_server(
1517 "yyy",
1518 vec![
1519 context_server::types::Tool {
1520 name: "echo".into(), // Also conflicts with native EchoTool
1521 description: None,
1522 input_schema: serde_json::to_value(EchoTool::input_schema(
1523 LanguageModelToolSchemaFormat::JsonSchema,
1524 ))
1525 .unwrap(),
1526 output_schema: None,
1527 annotations: None,
1528 },
1529 context_server::types::Tool {
1530 name: "unique_tool_2".into(),
1531 description: None,
1532 input_schema: json!({"type": "object", "properties": {}}),
1533 output_schema: None,
1534 annotations: None,
1535 },
1536 context_server::types::Tool {
1537 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1538 description: None,
1539 input_schema: json!({"type": "object", "properties": {}}),
1540 output_schema: None,
1541 annotations: None,
1542 },
1543 context_server::types::Tool {
1544 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1545 description: None,
1546 input_schema: json!({"type": "object", "properties": {}}),
1547 output_schema: None,
1548 annotations: None,
1549 },
1550 ],
1551 &context_server_store,
1552 cx,
1553 );
1554 let _server3_calls = setup_context_server(
1555 "zzz",
1556 vec![
1557 context_server::types::Tool {
1558 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1559 description: None,
1560 input_schema: json!({"type": "object", "properties": {}}),
1561 output_schema: None,
1562 annotations: None,
1563 },
1564 context_server::types::Tool {
1565 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1566 description: None,
1567 input_schema: json!({"type": "object", "properties": {}}),
1568 output_schema: None,
1569 annotations: None,
1570 },
1571 context_server::types::Tool {
1572 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1573 description: None,
1574 input_schema: json!({"type": "object", "properties": {}}),
1575 output_schema: None,
1576 annotations: None,
1577 },
1578 ],
1579 &context_server_store,
1580 cx,
1581 );
1582
1583 thread
1584 .update(cx, |thread, cx| {
1585 thread.send(UserMessageId::new(), ["Go"], cx)
1586 })
1587 .unwrap();
1588 cx.run_until_parked();
1589 let completion = fake_model.pending_completions().pop().unwrap();
1590 assert_eq!(
1591 tool_names_for_completion(&completion),
1592 vec![
1593 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1594 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1595 "delay",
1596 "echo",
1597 "infinite",
1598 "tool_requiring_permission",
1599 "unique_tool_1",
1600 "unique_tool_2",
1601 "word_list",
1602 "xxx_echo",
1603 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1604 "yyy_echo",
1605 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1606 ]
1607 );
1608}
1609
/// End-to-end test (ignored unless the `e2e` feature is enabled): cancels a
/// turn once the echo tool has completed and the infinite tool has been called,
/// checks that the event stream ends with `StopReason::Cancelled`, and then
/// checks that a fresh message can still be sent and answered.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // `expected_tools` is consumed from the front, so the tool calls must
    // arrive in exactly this order.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    // Remember the echo call's id so its completion can be recognized below.
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // A field update marking the echo tool call as completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tool calls were seen and the echo call has completed.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1695
1696#[gpui::test]
1697async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1698 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1699 always_allow_tools(cx);
1700 let fake_model = model.as_fake();
1701
1702 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1703 let environment = Rc::new(FakeThreadEnvironment {
1704 handle: handle.clone(),
1705 });
1706
1707 let mut events = thread
1708 .update(cx, |thread, cx| {
1709 thread.add_tool(crate::TerminalTool::new(
1710 thread.project().clone(),
1711 environment,
1712 ));
1713 thread.send(UserMessageId::new(), ["run a command"], cx)
1714 })
1715 .unwrap();
1716
1717 cx.run_until_parked();
1718
1719 // Simulate the model calling the terminal tool
1720 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1721 LanguageModelToolUse {
1722 id: "terminal_tool_1".into(),
1723 name: "terminal".into(),
1724 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1725 input: json!({"command": "sleep 1000", "cd": "."}),
1726 is_input_complete: true,
1727 thought_signature: None,
1728 },
1729 ));
1730 fake_model.end_last_completion_stream();
1731
1732 // Wait for the terminal tool to start running
1733 wait_for_terminal_tool_started(&mut events, cx).await;
1734
1735 // Cancel the thread while the terminal is running
1736 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1737
1738 // Collect remaining events, driving the executor to let cancellation complete
1739 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1740
1741 // Verify the terminal was killed
1742 assert!(
1743 handle.was_killed(),
1744 "expected terminal handle to be killed on cancellation"
1745 );
1746
1747 // Verify we got a cancellation stop event
1748 assert_eq!(
1749 stop_events(remaining_events),
1750 vec![acp::StopReason::Cancelled],
1751 );
1752
1753 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1754 thread.update(cx, |thread, _cx| {
1755 let message = thread.last_message().unwrap();
1756 let agent_message = message.as_agent_message().unwrap();
1757
1758 let tool_use = agent_message
1759 .content
1760 .iter()
1761 .find_map(|content| match content {
1762 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1763 _ => None,
1764 })
1765 .expect("expected tool use in agent message");
1766
1767 let tool_result = agent_message
1768 .tool_results
1769 .get(&tool_use.id)
1770 .expect("expected tool result");
1771
1772 let result_text = match &tool_result.content {
1773 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1774 _ => panic!("expected text content in tool result"),
1775 };
1776
1777 // "partial output" comes from FakeTerminalHandle's output field
1778 assert!(
1779 result_text.contains("partial output"),
1780 "expected tool result to contain terminal output, got: {result_text}"
1781 );
1782 // Match the actual format from process_content in terminal_tool.rs
1783 assert!(
1784 result_text.contains("The user stopped this command"),
1785 "expected tool result to indicate user stopped, got: {result_text}"
1786 );
1787 });
1788
1789 // Verify we can send a new message after cancellation
1790 verify_thread_recovery(&thread, &fake_model, cx).await;
1791}
1792
1793#[gpui::test]
1794async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1795 // This test verifies that tools which properly handle cancellation via
1796 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1797 // to cancellation and report that they were cancelled.
1798 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1799 always_allow_tools(cx);
1800 let fake_model = model.as_fake();
1801
1802 let (tool, was_cancelled) = CancellationAwareTool::new();
1803
1804 let mut events = thread
1805 .update(cx, |thread, cx| {
1806 thread.add_tool(tool);
1807 thread.send(
1808 UserMessageId::new(),
1809 ["call the cancellation aware tool"],
1810 cx,
1811 )
1812 })
1813 .unwrap();
1814
1815 cx.run_until_parked();
1816
1817 // Simulate the model calling the cancellation-aware tool
1818 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1819 LanguageModelToolUse {
1820 id: "cancellation_aware_1".into(),
1821 name: "cancellation_aware".into(),
1822 raw_input: r#"{}"#.into(),
1823 input: json!({}),
1824 is_input_complete: true,
1825 thought_signature: None,
1826 },
1827 ));
1828 fake_model.end_last_completion_stream();
1829
1830 cx.run_until_parked();
1831
1832 // Wait for the tool call to be reported
1833 let mut tool_started = false;
1834 let deadline = cx.executor().num_cpus() * 100;
1835 for _ in 0..deadline {
1836 cx.run_until_parked();
1837
1838 while let Some(Some(event)) = events.next().now_or_never() {
1839 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1840 if tool_call.title == "Cancellation Aware Tool" {
1841 tool_started = true;
1842 break;
1843 }
1844 }
1845 }
1846
1847 if tool_started {
1848 break;
1849 }
1850
1851 cx.background_executor
1852 .timer(Duration::from_millis(10))
1853 .await;
1854 }
1855 assert!(tool_started, "expected cancellation aware tool to start");
1856
1857 // Cancel the thread and wait for it to complete
1858 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1859
1860 // The cancel task should complete promptly because the tool handles cancellation
1861 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1862 futures::select! {
1863 _ = cancel_task.fuse() => {}
1864 _ = timeout.fuse() => {
1865 panic!("cancel task timed out - tool did not respond to cancellation");
1866 }
1867 }
1868
1869 // Verify the tool detected cancellation via its flag
1870 assert!(
1871 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1872 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1873 );
1874
1875 // Collect remaining events
1876 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1877
1878 // Verify we got a cancellation stop event
1879 assert_eq!(
1880 stop_events(remaining_events),
1881 vec![acp::StopReason::Cancelled],
1882 );
1883
1884 // Verify we can send a new message after cancellation
1885 verify_thread_recovery(&thread, &fake_model, cx).await;
1886}
1887
1888/// Helper to verify thread can recover after cancellation by sending a simple message.
1889async fn verify_thread_recovery(
1890 thread: &Entity<Thread>,
1891 fake_model: &FakeLanguageModel,
1892 cx: &mut TestAppContext,
1893) {
1894 let events = thread
1895 .update(cx, |thread, cx| {
1896 thread.send(
1897 UserMessageId::new(),
1898 ["Testing: reply with 'Hello' then stop."],
1899 cx,
1900 )
1901 })
1902 .unwrap();
1903 cx.run_until_parked();
1904 fake_model.send_last_completion_stream_text_chunk("Hello");
1905 fake_model
1906 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1907 fake_model.end_last_completion_stream();
1908
1909 let events = events.collect::<Vec<_>>().await;
1910 thread.update(cx, |thread, _cx| {
1911 let message = thread.last_message().unwrap();
1912 let agent_message = message.as_agent_message().unwrap();
1913 assert_eq!(
1914 agent_message.content,
1915 vec![AgentMessageContent::Text("Hello".to_string())]
1916 );
1917 });
1918 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1919}
1920
1921/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1922async fn wait_for_terminal_tool_started(
1923 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1924 cx: &mut TestAppContext,
1925) {
1926 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1927 for _ in 0..deadline {
1928 cx.run_until_parked();
1929
1930 while let Some(Some(event)) = events.next().now_or_never() {
1931 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1932 update,
1933 ))) = &event
1934 {
1935 if update.fields.content.as_ref().is_some_and(|content| {
1936 content
1937 .iter()
1938 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1939 }) {
1940 return;
1941 }
1942 }
1943 }
1944
1945 cx.background_executor
1946 .timer(Duration::from_millis(10))
1947 .await;
1948 }
1949 panic!("terminal tool did not start within the expected time");
1950}
1951
1952/// Collects events until a Stop event is received, driving the executor to completion.
1953async fn collect_events_until_stop(
1954 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1955 cx: &mut TestAppContext,
1956) -> Vec<Result<ThreadEvent>> {
1957 let mut collected = Vec::new();
1958 let deadline = cx.executor().num_cpus() * 200;
1959
1960 for _ in 0..deadline {
1961 cx.executor().advance_clock(Duration::from_millis(10));
1962 cx.run_until_parked();
1963
1964 while let Some(Some(event)) = events.next().now_or_never() {
1965 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1966 collected.push(event);
1967 if is_stop {
1968 return collected;
1969 }
1970 }
1971 }
1972 panic!(
1973 "did not receive Stop event within the expected time; collected {} events",
1974 collected.len()
1975 );
1976}
1977
1978#[gpui::test]
1979async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1980 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1981 always_allow_tools(cx);
1982 let fake_model = model.as_fake();
1983
1984 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1985 let environment = Rc::new(FakeThreadEnvironment {
1986 handle: handle.clone(),
1987 });
1988
1989 let message_id = UserMessageId::new();
1990 let mut events = thread
1991 .update(cx, |thread, cx| {
1992 thread.add_tool(crate::TerminalTool::new(
1993 thread.project().clone(),
1994 environment,
1995 ));
1996 thread.send(message_id.clone(), ["run a command"], cx)
1997 })
1998 .unwrap();
1999
2000 cx.run_until_parked();
2001
2002 // Simulate the model calling the terminal tool
2003 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2004 LanguageModelToolUse {
2005 id: "terminal_tool_1".into(),
2006 name: "terminal".into(),
2007 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2008 input: json!({"command": "sleep 1000", "cd": "."}),
2009 is_input_complete: true,
2010 thought_signature: None,
2011 },
2012 ));
2013 fake_model.end_last_completion_stream();
2014
2015 // Wait for the terminal tool to start running
2016 wait_for_terminal_tool_started(&mut events, cx).await;
2017
2018 // Truncate the thread while the terminal is running
2019 thread
2020 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2021 .unwrap();
2022
2023 // Drive the executor to let cancellation complete
2024 let _ = collect_events_until_stop(&mut events, cx).await;
2025
2026 // Verify the terminal was killed
2027 assert!(
2028 handle.was_killed(),
2029 "expected terminal handle to be killed on truncate"
2030 );
2031
2032 // Verify the thread is empty after truncation
2033 thread.update(cx, |thread, _cx| {
2034 assert_eq!(
2035 thread.to_markdown(),
2036 "",
2037 "expected thread to be empty after truncating the only message"
2038 );
2039 });
2040
2041 // Verify we can send a new message after truncation
2042 verify_thread_recovery(&thread, &fake_model, cx).await;
2043}
2044
2045#[gpui::test]
2046async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2047 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2048 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2049 always_allow_tools(cx);
2050 let fake_model = model.as_fake();
2051
2052 let environment = Rc::new(MultiTerminalEnvironment::new());
2053
2054 let mut events = thread
2055 .update(cx, |thread, cx| {
2056 thread.add_tool(crate::TerminalTool::new(
2057 thread.project().clone(),
2058 environment.clone(),
2059 ));
2060 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2061 })
2062 .unwrap();
2063
2064 cx.run_until_parked();
2065
2066 // Simulate the model calling two terminal tools
2067 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2068 LanguageModelToolUse {
2069 id: "terminal_tool_1".into(),
2070 name: "terminal".into(),
2071 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2072 input: json!({"command": "sleep 1000", "cd": "."}),
2073 is_input_complete: true,
2074 thought_signature: None,
2075 },
2076 ));
2077 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2078 LanguageModelToolUse {
2079 id: "terminal_tool_2".into(),
2080 name: "terminal".into(),
2081 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2082 input: json!({"command": "sleep 2000", "cd": "."}),
2083 is_input_complete: true,
2084 thought_signature: None,
2085 },
2086 ));
2087 fake_model.end_last_completion_stream();
2088
2089 // Wait for both terminal tools to start by counting terminal content updates
2090 let mut terminals_started = 0;
2091 let deadline = cx.executor().num_cpus() * 100;
2092 for _ in 0..deadline {
2093 cx.run_until_parked();
2094
2095 while let Some(Some(event)) = events.next().now_or_never() {
2096 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2097 update,
2098 ))) = &event
2099 {
2100 if update.fields.content.as_ref().is_some_and(|content| {
2101 content
2102 .iter()
2103 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2104 }) {
2105 terminals_started += 1;
2106 if terminals_started >= 2 {
2107 break;
2108 }
2109 }
2110 }
2111 }
2112 if terminals_started >= 2 {
2113 break;
2114 }
2115
2116 cx.background_executor
2117 .timer(Duration::from_millis(10))
2118 .await;
2119 }
2120 assert!(
2121 terminals_started >= 2,
2122 "expected 2 terminal tools to start, got {terminals_started}"
2123 );
2124
2125 // Cancel the thread while both terminals are running
2126 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2127
2128 // Collect remaining events
2129 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2130
2131 // Verify both terminal handles were killed
2132 let handles = environment.handles();
2133 assert_eq!(
2134 handles.len(),
2135 2,
2136 "expected 2 terminal handles to be created"
2137 );
2138 assert!(
2139 handles[0].was_killed(),
2140 "expected first terminal handle to be killed on cancellation"
2141 );
2142 assert!(
2143 handles[1].was_killed(),
2144 "expected second terminal handle to be killed on cancellation"
2145 );
2146
2147 // Verify we got a cancellation stop event
2148 assert_eq!(
2149 stop_events(remaining_events),
2150 vec![acp::StopReason::Cancelled],
2151 );
2152}
2153
2154#[gpui::test]
2155async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2156 // Tests that clicking the stop button on the terminal card (as opposed to the main
2157 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2158 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2159 always_allow_tools(cx);
2160 let fake_model = model.as_fake();
2161
2162 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2163 let environment = Rc::new(FakeThreadEnvironment {
2164 handle: handle.clone(),
2165 });
2166
2167 let mut events = thread
2168 .update(cx, |thread, cx| {
2169 thread.add_tool(crate::TerminalTool::new(
2170 thread.project().clone(),
2171 environment,
2172 ));
2173 thread.send(UserMessageId::new(), ["run a command"], cx)
2174 })
2175 .unwrap();
2176
2177 cx.run_until_parked();
2178
2179 // Simulate the model calling the terminal tool
2180 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2181 LanguageModelToolUse {
2182 id: "terminal_tool_1".into(),
2183 name: "terminal".into(),
2184 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2185 input: json!({"command": "sleep 1000", "cd": "."}),
2186 is_input_complete: true,
2187 thought_signature: None,
2188 },
2189 ));
2190 fake_model.end_last_completion_stream();
2191
2192 // Wait for the terminal tool to start running
2193 wait_for_terminal_tool_started(&mut events, cx).await;
2194
2195 // Simulate user clicking stop on the terminal card itself.
2196 // This sets the flag and signals exit (simulating what the real UI would do).
2197 handle.set_stopped_by_user(true);
2198 handle.killed.store(true, Ordering::SeqCst);
2199 handle.signal_exit();
2200
2201 // Wait for the tool to complete
2202 cx.run_until_parked();
2203
2204 // The thread continues after tool completion - simulate the model ending its turn
2205 fake_model
2206 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2207 fake_model.end_last_completion_stream();
2208
2209 // Collect remaining events
2210 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2211
2212 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2213 assert_eq!(
2214 stop_events(remaining_events),
2215 vec![acp::StopReason::EndTurn],
2216 );
2217
2218 // Verify the tool result indicates user stopped
2219 thread.update(cx, |thread, _cx| {
2220 let message = thread.last_message().unwrap();
2221 let agent_message = message.as_agent_message().unwrap();
2222
2223 let tool_use = agent_message
2224 .content
2225 .iter()
2226 .find_map(|content| match content {
2227 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2228 _ => None,
2229 })
2230 .expect("expected tool use in agent message");
2231
2232 let tool_result = agent_message
2233 .tool_results
2234 .get(&tool_use.id)
2235 .expect("expected tool result");
2236
2237 let result_text = match &tool_result.content {
2238 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2239 _ => panic!("expected text content in tool result"),
2240 };
2241
2242 assert!(
2243 result_text.contains("The user stopped this command"),
2244 "expected tool result to indicate user stopped, got: {result_text}"
2245 );
2246 });
2247}
2248
2249#[gpui::test]
2250async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2251 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2252 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2253 always_allow_tools(cx);
2254 let fake_model = model.as_fake();
2255
2256 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2257 let environment = Rc::new(FakeThreadEnvironment {
2258 handle: handle.clone(),
2259 });
2260
2261 let mut events = thread
2262 .update(cx, |thread, cx| {
2263 thread.add_tool(crate::TerminalTool::new(
2264 thread.project().clone(),
2265 environment,
2266 ));
2267 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2268 })
2269 .unwrap();
2270
2271 cx.run_until_parked();
2272
2273 // Simulate the model calling the terminal tool with a short timeout
2274 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2275 LanguageModelToolUse {
2276 id: "terminal_tool_1".into(),
2277 name: "terminal".into(),
2278 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2279 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2280 is_input_complete: true,
2281 thought_signature: None,
2282 },
2283 ));
2284 fake_model.end_last_completion_stream();
2285
2286 // Wait for the terminal tool to start running
2287 wait_for_terminal_tool_started(&mut events, cx).await;
2288
2289 // Advance clock past the timeout
2290 cx.executor().advance_clock(Duration::from_millis(200));
2291 cx.run_until_parked();
2292
2293 // The thread continues after tool completion - simulate the model ending its turn
2294 fake_model
2295 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2296 fake_model.end_last_completion_stream();
2297
2298 // Collect remaining events
2299 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2300
2301 // Verify the terminal was killed due to timeout
2302 assert!(
2303 handle.was_killed(),
2304 "expected terminal handle to be killed on timeout"
2305 );
2306
2307 // Verify we got an EndTurn (the tool completed, just with timeout)
2308 assert_eq!(
2309 stop_events(remaining_events),
2310 vec![acp::StopReason::EndTurn],
2311 );
2312
2313 // Verify the tool result indicates timeout, not user stopped
2314 thread.update(cx, |thread, _cx| {
2315 let message = thread.last_message().unwrap();
2316 let agent_message = message.as_agent_message().unwrap();
2317
2318 let tool_use = agent_message
2319 .content
2320 .iter()
2321 .find_map(|content| match content {
2322 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2323 _ => None,
2324 })
2325 .expect("expected tool use in agent message");
2326
2327 let tool_result = agent_message
2328 .tool_results
2329 .get(&tool_use.id)
2330 .expect("expected tool result");
2331
2332 let result_text = match &tool_result.content {
2333 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2334 _ => panic!("expected text content in tool result"),
2335 };
2336
2337 assert!(
2338 result_text.contains("timed out"),
2339 "expected tool result to indicate timeout, got: {result_text}"
2340 );
2341 assert!(
2342 !result_text.contains("The user stopped"),
2343 "tool result should not mention user stopped when it timed out, got: {result_text}"
2344 );
2345 });
2346}
2347
2348#[gpui::test]
2349async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2350 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2351 let fake_model = model.as_fake();
2352
2353 let events_1 = thread
2354 .update(cx, |thread, cx| {
2355 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2356 })
2357 .unwrap();
2358 cx.run_until_parked();
2359 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2360 cx.run_until_parked();
2361
2362 let events_2 = thread
2363 .update(cx, |thread, cx| {
2364 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2365 })
2366 .unwrap();
2367 cx.run_until_parked();
2368 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2369 fake_model
2370 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2371 fake_model.end_last_completion_stream();
2372
2373 let events_1 = events_1.collect::<Vec<_>>().await;
2374 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2375 let events_2 = events_2.collect::<Vec<_>>().await;
2376 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2377}
2378
2379#[gpui::test]
2380async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2381 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2382 let fake_model = model.as_fake();
2383
2384 let events_1 = thread
2385 .update(cx, |thread, cx| {
2386 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2387 })
2388 .unwrap();
2389 cx.run_until_parked();
2390 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2391 fake_model
2392 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2393 fake_model.end_last_completion_stream();
2394 let events_1 = events_1.collect::<Vec<_>>().await;
2395
2396 let events_2 = thread
2397 .update(cx, |thread, cx| {
2398 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2399 })
2400 .unwrap();
2401 cx.run_until_parked();
2402 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2403 fake_model
2404 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2405 fake_model.end_last_completion_stream();
2406 let events_2 = events_2.collect::<Vec<_>>().await;
2407
2408 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2409 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2410}
2411
2412#[gpui::test]
2413async fn test_refusal(cx: &mut TestAppContext) {
2414 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2415 let fake_model = model.as_fake();
2416
2417 let events = thread
2418 .update(cx, |thread, cx| {
2419 thread.send(UserMessageId::new(), ["Hello"], cx)
2420 })
2421 .unwrap();
2422 cx.run_until_parked();
2423 thread.read_with(cx, |thread, _| {
2424 assert_eq!(
2425 thread.to_markdown(),
2426 indoc! {"
2427 ## User
2428
2429 Hello
2430 "}
2431 );
2432 });
2433
2434 fake_model.send_last_completion_stream_text_chunk("Hey!");
2435 cx.run_until_parked();
2436 thread.read_with(cx, |thread, _| {
2437 assert_eq!(
2438 thread.to_markdown(),
2439 indoc! {"
2440 ## User
2441
2442 Hello
2443
2444 ## Assistant
2445
2446 Hey!
2447 "}
2448 );
2449 });
2450
2451 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2452 fake_model
2453 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2454 let events = events.collect::<Vec<_>>().await;
2455 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2456 thread.read_with(cx, |thread, _| {
2457 assert_eq!(thread.to_markdown(), "");
2458 });
2459}
2460
2461#[gpui::test]
2462async fn test_truncate_first_message(cx: &mut TestAppContext) {
2463 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2464 let fake_model = model.as_fake();
2465
2466 let message_id = UserMessageId::new();
2467 thread
2468 .update(cx, |thread, cx| {
2469 thread.send(message_id.clone(), ["Hello"], cx)
2470 })
2471 .unwrap();
2472 cx.run_until_parked();
2473 thread.read_with(cx, |thread, _| {
2474 assert_eq!(
2475 thread.to_markdown(),
2476 indoc! {"
2477 ## User
2478
2479 Hello
2480 "}
2481 );
2482 assert_eq!(thread.latest_token_usage(), None);
2483 });
2484
2485 fake_model.send_last_completion_stream_text_chunk("Hey!");
2486 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2487 language_model::TokenUsage {
2488 input_tokens: 32_000,
2489 output_tokens: 16_000,
2490 cache_creation_input_tokens: 0,
2491 cache_read_input_tokens: 0,
2492 },
2493 ));
2494 cx.run_until_parked();
2495 thread.read_with(cx, |thread, _| {
2496 assert_eq!(
2497 thread.to_markdown(),
2498 indoc! {"
2499 ## User
2500
2501 Hello
2502
2503 ## Assistant
2504
2505 Hey!
2506 "}
2507 );
2508 assert_eq!(
2509 thread.latest_token_usage(),
2510 Some(acp_thread::TokenUsage {
2511 used_tokens: 32_000 + 16_000,
2512 max_tokens: 1_000_000,
2513 output_tokens: 16_000,
2514 })
2515 );
2516 });
2517
2518 thread
2519 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2520 .unwrap();
2521 cx.run_until_parked();
2522 thread.read_with(cx, |thread, _| {
2523 assert_eq!(thread.to_markdown(), "");
2524 assert_eq!(thread.latest_token_usage(), None);
2525 });
2526
2527 // Ensure we can still send a new message after truncation.
2528 thread
2529 .update(cx, |thread, cx| {
2530 thread.send(UserMessageId::new(), ["Hi"], cx)
2531 })
2532 .unwrap();
2533 thread.update(cx, |thread, _cx| {
2534 assert_eq!(
2535 thread.to_markdown(),
2536 indoc! {"
2537 ## User
2538
2539 Hi
2540 "}
2541 );
2542 });
2543 cx.run_until_parked();
2544 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2545 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2546 language_model::TokenUsage {
2547 input_tokens: 40_000,
2548 output_tokens: 20_000,
2549 cache_creation_input_tokens: 0,
2550 cache_read_input_tokens: 0,
2551 },
2552 ));
2553 cx.run_until_parked();
2554 thread.read_with(cx, |thread, _| {
2555 assert_eq!(
2556 thread.to_markdown(),
2557 indoc! {"
2558 ## User
2559
2560 Hi
2561
2562 ## Assistant
2563
2564 Ahoy!
2565 "}
2566 );
2567
2568 assert_eq!(
2569 thread.latest_token_usage(),
2570 Some(acp_thread::TokenUsage {
2571 used_tokens: 40_000 + 20_000,
2572 max_tokens: 1_000_000,
2573 output_tokens: 20_000,
2574 })
2575 );
2576 });
2577}
2578
2579#[gpui::test]
2580async fn test_truncate_second_message(cx: &mut TestAppContext) {
2581 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2582 let fake_model = model.as_fake();
2583
2584 thread
2585 .update(cx, |thread, cx| {
2586 thread.send(UserMessageId::new(), ["Message 1"], cx)
2587 })
2588 .unwrap();
2589 cx.run_until_parked();
2590 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2591 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2592 language_model::TokenUsage {
2593 input_tokens: 32_000,
2594 output_tokens: 16_000,
2595 cache_creation_input_tokens: 0,
2596 cache_read_input_tokens: 0,
2597 },
2598 ));
2599 fake_model.end_last_completion_stream();
2600 cx.run_until_parked();
2601
2602 let assert_first_message_state = |cx: &mut TestAppContext| {
2603 thread.clone().read_with(cx, |thread, _| {
2604 assert_eq!(
2605 thread.to_markdown(),
2606 indoc! {"
2607 ## User
2608
2609 Message 1
2610
2611 ## Assistant
2612
2613 Message 1 response
2614 "}
2615 );
2616
2617 assert_eq!(
2618 thread.latest_token_usage(),
2619 Some(acp_thread::TokenUsage {
2620 used_tokens: 32_000 + 16_000,
2621 max_tokens: 1_000_000,
2622 output_tokens: 16_000,
2623 })
2624 );
2625 });
2626 };
2627
2628 assert_first_message_state(cx);
2629
2630 let second_message_id = UserMessageId::new();
2631 thread
2632 .update(cx, |thread, cx| {
2633 thread.send(second_message_id.clone(), ["Message 2"], cx)
2634 })
2635 .unwrap();
2636 cx.run_until_parked();
2637
2638 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2639 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2640 language_model::TokenUsage {
2641 input_tokens: 40_000,
2642 output_tokens: 20_000,
2643 cache_creation_input_tokens: 0,
2644 cache_read_input_tokens: 0,
2645 },
2646 ));
2647 fake_model.end_last_completion_stream();
2648 cx.run_until_parked();
2649
2650 thread.read_with(cx, |thread, _| {
2651 assert_eq!(
2652 thread.to_markdown(),
2653 indoc! {"
2654 ## User
2655
2656 Message 1
2657
2658 ## Assistant
2659
2660 Message 1 response
2661
2662 ## User
2663
2664 Message 2
2665
2666 ## Assistant
2667
2668 Message 2 response
2669 "}
2670 );
2671
2672 assert_eq!(
2673 thread.latest_token_usage(),
2674 Some(acp_thread::TokenUsage {
2675 used_tokens: 40_000 + 20_000,
2676 max_tokens: 1_000_000,
2677 output_tokens: 20_000,
2678 })
2679 );
2680 });
2681
2682 thread
2683 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2684 .unwrap();
2685 cx.run_until_parked();
2686
2687 assert_first_message_state(cx);
2688}
2689
2690#[gpui::test]
2691async fn test_title_generation(cx: &mut TestAppContext) {
2692 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2693 let fake_model = model.as_fake();
2694
2695 let summary_model = Arc::new(FakeLanguageModel::default());
2696 thread.update(cx, |thread, cx| {
2697 thread.set_summarization_model(Some(summary_model.clone()), cx)
2698 });
2699
2700 let send = thread
2701 .update(cx, |thread, cx| {
2702 thread.send(UserMessageId::new(), ["Hello"], cx)
2703 })
2704 .unwrap();
2705 cx.run_until_parked();
2706
2707 fake_model.send_last_completion_stream_text_chunk("Hey!");
2708 fake_model.end_last_completion_stream();
2709 cx.run_until_parked();
2710 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2711
2712 // Ensure the summary model has been invoked to generate a title.
2713 summary_model.send_last_completion_stream_text_chunk("Hello ");
2714 summary_model.send_last_completion_stream_text_chunk("world\nG");
2715 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2716 summary_model.end_last_completion_stream();
2717 send.collect::<Vec<_>>().await;
2718 cx.run_until_parked();
2719 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2720
2721 // Send another message, ensuring no title is generated this time.
2722 let send = thread
2723 .update(cx, |thread, cx| {
2724 thread.send(UserMessageId::new(), ["Hello again"], cx)
2725 })
2726 .unwrap();
2727 cx.run_until_parked();
2728 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2729 fake_model.end_last_completion_stream();
2730 cx.run_until_parked();
2731 assert_eq!(summary_model.pending_completions(), Vec::new());
2732 send.collect::<Vec<_>>().await;
2733 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2734}
2735
2736#[gpui::test]
2737async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2738 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2739 let fake_model = model.as_fake();
2740
2741 let _events = thread
2742 .update(cx, |thread, cx| {
2743 thread.add_tool(ToolRequiringPermission);
2744 thread.add_tool(EchoTool);
2745 thread.send(UserMessageId::new(), ["Hey!"], cx)
2746 })
2747 .unwrap();
2748 cx.run_until_parked();
2749
2750 let permission_tool_use = LanguageModelToolUse {
2751 id: "tool_id_1".into(),
2752 name: ToolRequiringPermission::name().into(),
2753 raw_input: "{}".into(),
2754 input: json!({}),
2755 is_input_complete: true,
2756 thought_signature: None,
2757 };
2758 let echo_tool_use = LanguageModelToolUse {
2759 id: "tool_id_2".into(),
2760 name: EchoTool::name().into(),
2761 raw_input: json!({"text": "test"}).to_string(),
2762 input: json!({"text": "test"}),
2763 is_input_complete: true,
2764 thought_signature: None,
2765 };
2766 fake_model.send_last_completion_stream_text_chunk("Hi!");
2767 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2768 permission_tool_use,
2769 ));
2770 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2771 echo_tool_use.clone(),
2772 ));
2773 fake_model.end_last_completion_stream();
2774 cx.run_until_parked();
2775
2776 // Ensure pending tools are skipped when building a request.
2777 let request = thread
2778 .read_with(cx, |thread, cx| {
2779 thread.build_completion_request(CompletionIntent::EditFile, cx)
2780 })
2781 .unwrap();
2782 assert_eq!(
2783 request.messages[1..],
2784 vec![
2785 LanguageModelRequestMessage {
2786 role: Role::User,
2787 content: vec!["Hey!".into()],
2788 cache: true,
2789 reasoning_details: None,
2790 },
2791 LanguageModelRequestMessage {
2792 role: Role::Assistant,
2793 content: vec![
2794 MessageContent::Text("Hi!".into()),
2795 MessageContent::ToolUse(echo_tool_use.clone())
2796 ],
2797 cache: false,
2798 reasoning_details: None,
2799 },
2800 LanguageModelRequestMessage {
2801 role: Role::User,
2802 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2803 tool_use_id: echo_tool_use.id.clone(),
2804 tool_name: echo_tool_use.name,
2805 is_error: false,
2806 content: "test".into(),
2807 output: Some("test".into())
2808 })],
2809 cache: false,
2810 reasoning_details: None,
2811 },
2812 ],
2813 );
2814}
2815
2816#[gpui::test]
2817async fn test_agent_connection(cx: &mut TestAppContext) {
2818 cx.update(settings::init);
2819 let templates = Templates::new();
2820
2821 // Initialize language model system with test provider
2822 cx.update(|cx| {
2823 gpui_tokio::init(cx);
2824
2825 let http_client = FakeHttpClient::with_404_response();
2826 let clock = Arc::new(clock::FakeSystemClock::new());
2827 let client = Client::new(clock, http_client, cx);
2828 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2829 language_model::init(client.clone(), cx);
2830 language_models::init(user_store, client.clone(), cx);
2831 LanguageModelRegistry::test(cx);
2832 });
2833 cx.executor().forbid_parking();
2834
2835 // Create a project for new_thread
2836 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2837 fake_fs.insert_tree(path!("/test"), json!({})).await;
2838 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2839 let cwd = Path::new("/test");
2840 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2841
2842 // Create agent and connection
2843 let agent = NativeAgent::new(
2844 project.clone(),
2845 thread_store,
2846 templates.clone(),
2847 None,
2848 fake_fs.clone(),
2849 &mut cx.to_async(),
2850 )
2851 .await
2852 .unwrap();
2853 let connection = NativeAgentConnection(agent.clone());
2854
2855 // Create a thread using new_thread
2856 let connection_rc = Rc::new(connection.clone());
2857 let acp_thread = cx
2858 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2859 .await
2860 .expect("new_thread should succeed");
2861
2862 // Get the session_id from the AcpThread
2863 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2864
2865 // Test model_selector returns Some
2866 let selector_opt = connection.model_selector(&session_id);
2867 assert!(
2868 selector_opt.is_some(),
2869 "agent should always support ModelSelector"
2870 );
2871 let selector = selector_opt.unwrap();
2872
2873 // Test list_models
2874 let listed_models = cx
2875 .update(|cx| selector.list_models(cx))
2876 .await
2877 .expect("list_models should succeed");
2878 let AgentModelList::Grouped(listed_models) = listed_models else {
2879 panic!("Unexpected model list type");
2880 };
2881 assert!(!listed_models.is_empty(), "should have at least one model");
2882 assert_eq!(
2883 listed_models[&AgentModelGroupName("Fake".into())][0]
2884 .id
2885 .0
2886 .as_ref(),
2887 "fake/fake"
2888 );
2889
2890 // Test selected_model returns the default
2891 let model = cx
2892 .update(|cx| selector.selected_model(cx))
2893 .await
2894 .expect("selected_model should succeed");
2895 let model = cx
2896 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2897 .unwrap();
2898 let model = model.as_fake();
2899 assert_eq!(model.id().0, "fake", "should return default model");
2900
2901 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2902 cx.run_until_parked();
2903 model.send_last_completion_stream_text_chunk("def");
2904 cx.run_until_parked();
2905 acp_thread.read_with(cx, |thread, cx| {
2906 assert_eq!(
2907 thread.to_markdown(cx),
2908 indoc! {"
2909 ## User
2910
2911 abc
2912
2913 ## Assistant
2914
2915 def
2916
2917 "}
2918 )
2919 });
2920
2921 // Test cancel
2922 cx.update(|cx| connection.cancel(&session_id, cx));
2923 request.await.expect("prompt should fail gracefully");
2924
2925 // Ensure that dropping the ACP thread causes the native thread to be
2926 // dropped as well.
2927 cx.update(|_| drop(acp_thread));
2928 let result = cx
2929 .update(|cx| {
2930 connection.prompt(
2931 Some(acp_thread::UserMessageId::new()),
2932 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2933 cx,
2934 )
2935 })
2936 .await;
2937 assert_eq!(
2938 result.as_ref().unwrap_err().to_string(),
2939 "Session not found",
2940 "unexpected result: {:?}",
2941 result
2942 );
2943}
2944
2945#[gpui::test]
2946async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2947 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2948 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2949 let fake_model = model.as_fake();
2950
2951 let mut events = thread
2952 .update(cx, |thread, cx| {
2953 thread.send(UserMessageId::new(), ["Think"], cx)
2954 })
2955 .unwrap();
2956 cx.run_until_parked();
2957
2958 // Simulate streaming partial input.
2959 let input = json!({});
2960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2961 LanguageModelToolUse {
2962 id: "1".into(),
2963 name: ThinkingTool::name().into(),
2964 raw_input: input.to_string(),
2965 input,
2966 is_input_complete: false,
2967 thought_signature: None,
2968 },
2969 ));
2970
2971 // Input streaming completed
2972 let input = json!({ "content": "Thinking hard!" });
2973 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2974 LanguageModelToolUse {
2975 id: "1".into(),
2976 name: "thinking".into(),
2977 raw_input: input.to_string(),
2978 input,
2979 is_input_complete: true,
2980 thought_signature: None,
2981 },
2982 ));
2983 fake_model.end_last_completion_stream();
2984 cx.run_until_parked();
2985
2986 let tool_call = expect_tool_call(&mut events).await;
2987 assert_eq!(
2988 tool_call,
2989 acp::ToolCall::new("1", "Thinking")
2990 .kind(acp::ToolKind::Think)
2991 .raw_input(json!({}))
2992 .meta(acp::Meta::from_iter([(
2993 "tool_name".into(),
2994 "thinking".into()
2995 )]))
2996 );
2997 let update = expect_tool_call_update_fields(&mut events).await;
2998 assert_eq!(
2999 update,
3000 acp::ToolCallUpdate::new(
3001 "1",
3002 acp::ToolCallUpdateFields::new()
3003 .title("Thinking")
3004 .kind(acp::ToolKind::Think)
3005 .raw_input(json!({ "content": "Thinking hard!"}))
3006 )
3007 );
3008 let update = expect_tool_call_update_fields(&mut events).await;
3009 assert_eq!(
3010 update,
3011 acp::ToolCallUpdate::new(
3012 "1",
3013 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3014 )
3015 );
3016 let update = expect_tool_call_update_fields(&mut events).await;
3017 assert_eq!(
3018 update,
3019 acp::ToolCallUpdate::new(
3020 "1",
3021 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3022 )
3023 );
3024 let update = expect_tool_call_update_fields(&mut events).await;
3025 assert_eq!(
3026 update,
3027 acp::ToolCallUpdate::new(
3028 "1",
3029 acp::ToolCallUpdateFields::new()
3030 .status(acp::ToolCallStatus::Completed)
3031 .raw_output("Finished thinking.")
3032 )
3033 );
3034}
3035
3036#[gpui::test]
3037async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3038 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3039 let fake_model = model.as_fake();
3040
3041 let mut events = thread
3042 .update(cx, |thread, cx| {
3043 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3044 thread.send(UserMessageId::new(), ["Hello!"], cx)
3045 })
3046 .unwrap();
3047 cx.run_until_parked();
3048
3049 fake_model.send_last_completion_stream_text_chunk("Hey!");
3050 fake_model.end_last_completion_stream();
3051
3052 let mut retry_events = Vec::new();
3053 while let Some(Ok(event)) = events.next().await {
3054 match event {
3055 ThreadEvent::Retry(retry_status) => {
3056 retry_events.push(retry_status);
3057 }
3058 ThreadEvent::Stop(..) => break,
3059 _ => {}
3060 }
3061 }
3062
3063 assert_eq!(retry_events.len(), 0);
3064 thread.read_with(cx, |thread, _cx| {
3065 assert_eq!(
3066 thread.to_markdown(),
3067 indoc! {"
3068 ## User
3069
3070 Hello!
3071
3072 ## Assistant
3073
3074 Hey!
3075 "}
3076 )
3077 });
3078}
3079
3080#[gpui::test]
3081async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3082 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3083 let fake_model = model.as_fake();
3084
3085 let mut events = thread
3086 .update(cx, |thread, cx| {
3087 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3088 thread.send(UserMessageId::new(), ["Hello!"], cx)
3089 })
3090 .unwrap();
3091 cx.run_until_parked();
3092
3093 fake_model.send_last_completion_stream_text_chunk("Hey,");
3094 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3095 provider: LanguageModelProviderName::new("Anthropic"),
3096 retry_after: Some(Duration::from_secs(3)),
3097 });
3098 fake_model.end_last_completion_stream();
3099
3100 cx.executor().advance_clock(Duration::from_secs(3));
3101 cx.run_until_parked();
3102
3103 fake_model.send_last_completion_stream_text_chunk("there!");
3104 fake_model.end_last_completion_stream();
3105 cx.run_until_parked();
3106
3107 let mut retry_events = Vec::new();
3108 while let Some(Ok(event)) = events.next().await {
3109 match event {
3110 ThreadEvent::Retry(retry_status) => {
3111 retry_events.push(retry_status);
3112 }
3113 ThreadEvent::Stop(..) => break,
3114 _ => {}
3115 }
3116 }
3117
3118 assert_eq!(retry_events.len(), 1);
3119 assert!(matches!(
3120 retry_events[0],
3121 acp_thread::RetryStatus { attempt: 1, .. }
3122 ));
3123 thread.read_with(cx, |thread, _cx| {
3124 assert_eq!(
3125 thread.to_markdown(),
3126 indoc! {"
3127 ## User
3128
3129 Hello!
3130
3131 ## Assistant
3132
3133 Hey,
3134
3135 [resume]
3136
3137 ## Assistant
3138
3139 there!
3140 "}
3141 )
3142 });
3143}
3144
3145#[gpui::test]
3146async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3147 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3148 let fake_model = model.as_fake();
3149
3150 let events = thread
3151 .update(cx, |thread, cx| {
3152 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3153 thread.add_tool(EchoTool);
3154 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3155 })
3156 .unwrap();
3157 cx.run_until_parked();
3158
3159 let tool_use_1 = LanguageModelToolUse {
3160 id: "tool_1".into(),
3161 name: EchoTool::name().into(),
3162 raw_input: json!({"text": "test"}).to_string(),
3163 input: json!({"text": "test"}),
3164 is_input_complete: true,
3165 thought_signature: None,
3166 };
3167 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3168 tool_use_1.clone(),
3169 ));
3170 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3171 provider: LanguageModelProviderName::new("Anthropic"),
3172 retry_after: Some(Duration::from_secs(3)),
3173 });
3174 fake_model.end_last_completion_stream();
3175
3176 cx.executor().advance_clock(Duration::from_secs(3));
3177 let completion = fake_model.pending_completions().pop().unwrap();
3178 assert_eq!(
3179 completion.messages[1..],
3180 vec![
3181 LanguageModelRequestMessage {
3182 role: Role::User,
3183 content: vec!["Call the echo tool!".into()],
3184 cache: false,
3185 reasoning_details: None,
3186 },
3187 LanguageModelRequestMessage {
3188 role: Role::Assistant,
3189 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3190 cache: false,
3191 reasoning_details: None,
3192 },
3193 LanguageModelRequestMessage {
3194 role: Role::User,
3195 content: vec![language_model::MessageContent::ToolResult(
3196 LanguageModelToolResult {
3197 tool_use_id: tool_use_1.id.clone(),
3198 tool_name: tool_use_1.name.clone(),
3199 is_error: false,
3200 content: "test".into(),
3201 output: Some("test".into())
3202 }
3203 )],
3204 cache: true,
3205 reasoning_details: None,
3206 },
3207 ]
3208 );
3209
3210 fake_model.send_last_completion_stream_text_chunk("Done");
3211 fake_model.end_last_completion_stream();
3212 cx.run_until_parked();
3213 events.collect::<Vec<_>>().await;
3214 thread.read_with(cx, |thread, _cx| {
3215 assert_eq!(
3216 thread.last_message(),
3217 Some(Message::Agent(AgentMessage {
3218 content: vec![AgentMessageContent::Text("Done".into())],
3219 tool_results: IndexMap::default(),
3220 reasoning_details: None,
3221 }))
3222 );
3223 })
3224}
3225
3226#[gpui::test]
3227async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3228 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3229 let fake_model = model.as_fake();
3230
3231 let mut events = thread
3232 .update(cx, |thread, cx| {
3233 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3234 thread.send(UserMessageId::new(), ["Hello!"], cx)
3235 })
3236 .unwrap();
3237 cx.run_until_parked();
3238
3239 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3240 fake_model.send_last_completion_stream_error(
3241 LanguageModelCompletionError::ServerOverloaded {
3242 provider: LanguageModelProviderName::new("Anthropic"),
3243 retry_after: Some(Duration::from_secs(3)),
3244 },
3245 );
3246 fake_model.end_last_completion_stream();
3247 cx.executor().advance_clock(Duration::from_secs(3));
3248 cx.run_until_parked();
3249 }
3250
3251 let mut errors = Vec::new();
3252 let mut retry_events = Vec::new();
3253 while let Some(event) = events.next().await {
3254 match event {
3255 Ok(ThreadEvent::Retry(retry_status)) => {
3256 retry_events.push(retry_status);
3257 }
3258 Ok(ThreadEvent::Stop(..)) => break,
3259 Err(error) => errors.push(error),
3260 _ => {}
3261 }
3262 }
3263
3264 assert_eq!(
3265 retry_events.len(),
3266 crate::thread::MAX_RETRY_ATTEMPTS as usize
3267 );
3268 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3269 assert_eq!(retry_events[i].attempt, i + 1);
3270 }
3271 assert_eq!(errors.len(), 1);
3272 let error = errors[0]
3273 .downcast_ref::<LanguageModelCompletionError>()
3274 .unwrap();
3275 assert!(matches!(
3276 error,
3277 LanguageModelCompletionError::ServerOverloaded { .. }
3278 ));
3279}
3280
3281/// Filters out the stop events for asserting against in tests
3282fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3283 result_events
3284 .into_iter()
3285 .filter_map(|event| match event.unwrap() {
3286 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3287 _ => None,
3288 })
3289 .collect()
3290}
3291
/// Fixture returned by [`setup`]; tests destructure the pieces they need.
struct ThreadTest {
    /// The language model that was passed to the thread on construction.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem holding the settings file and the project tree.
    fs: Arc<FakeFs>,
}
3299
/// Selects which language model [`setup`] gives to the thread.
enum TestModel {
    /// A model looked up in the global registry by id; `setup` installs a
    /// production client and HTTP client before resolving it.
    Sonnet4,
    /// A default-constructed `FakeLanguageModel`.
    Fake,
}
3304
impl TestModel {
    /// Returns the id used to find this model among the registry's available models.
    ///
    /// # Panics
    ///
    /// Panics for [`TestModel::Fake`]: that arm is `unreachable!()`, so callers
    /// must handle the fake model before asking for an id.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
3313
3314async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3315 cx.executor().allow_parking();
3316
3317 let fs = FakeFs::new(cx.background_executor.clone());
3318 fs.create_dir(paths::settings_file().parent().unwrap())
3319 .await
3320 .unwrap();
3321 fs.insert_file(
3322 paths::settings_file(),
3323 json!({
3324 "agent": {
3325 "default_profile": "test-profile",
3326 "profiles": {
3327 "test-profile": {
3328 "name": "Test Profile",
3329 "tools": {
3330 EchoTool::name(): true,
3331 DelayTool::name(): true,
3332 WordListTool::name(): true,
3333 ToolRequiringPermission::name(): true,
3334 InfiniteTool::name(): true,
3335 CancellationAwareTool::name(): true,
3336 ThinkingTool::name(): true,
3337 "terminal": true,
3338 }
3339 }
3340 }
3341 }
3342 })
3343 .to_string()
3344 .into_bytes(),
3345 )
3346 .await;
3347
3348 cx.update(|cx| {
3349 settings::init(cx);
3350
3351 match model {
3352 TestModel::Fake => {}
3353 TestModel::Sonnet4 => {
3354 gpui_tokio::init(cx);
3355 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3356 cx.set_http_client(Arc::new(http_client));
3357 let client = Client::production(cx);
3358 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3359 language_model::init(client.clone(), cx);
3360 language_models::init(user_store, client.clone(), cx);
3361 }
3362 };
3363
3364 watch_settings(fs.clone(), cx);
3365 });
3366
3367 let templates = Templates::new();
3368
3369 fs.insert_tree(path!("/test"), json!({})).await;
3370 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3371
3372 let model = cx
3373 .update(|cx| {
3374 if let TestModel::Fake = model {
3375 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3376 } else {
3377 let model_id = model.id();
3378 let models = LanguageModelRegistry::read_global(cx);
3379 let model = models
3380 .available_models(cx)
3381 .find(|model| model.id() == model_id)
3382 .unwrap();
3383
3384 let provider = models.provider(&model.provider_id()).unwrap();
3385 let authenticated = provider.authenticate(cx);
3386
3387 cx.spawn(async move |_cx| {
3388 authenticated.await.unwrap();
3389 model
3390 })
3391 }
3392 })
3393 .await;
3394
3395 let project_context = cx.new(|_cx| ProjectContext::default());
3396 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3397 let context_server_registry =
3398 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3399 let thread = cx.new(|cx| {
3400 Thread::new(
3401 project,
3402 project_context.clone(),
3403 context_server_registry,
3404 templates,
3405 Some(model.clone()),
3406 cx,
3407 )
3408 });
3409 ThreadTest {
3410 model,
3411 thread,
3412 project_context,
3413 context_server_store,
3414 fs,
3415 }
3416}
3417
/// Initializes `env_logger` at load time, but only when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3425
3426fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3427 let fs = fs.clone();
3428 cx.spawn({
3429 async move |cx| {
3430 let mut new_settings_content_rx = settings::watch_config_file(
3431 cx.background_executor(),
3432 fs,
3433 paths::settings_file().clone(),
3434 );
3435
3436 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3437 cx.update(|cx| {
3438 SettingsStore::update_global(cx, |settings, cx| {
3439 settings.set_user_settings(&new_settings_content, cx)
3440 })
3441 })
3442 .ok();
3443 }
3444 }
3445 })
3446 .detach();
3447}
3448
3449fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3450 completion
3451 .tools
3452 .iter()
3453 .map(|tool| tool.name.clone())
3454 .collect()
3455}
3456
3457fn setup_context_server(
3458 name: &'static str,
3459 tools: Vec<context_server::types::Tool>,
3460 context_server_store: &Entity<ContextServerStore>,
3461 cx: &mut TestAppContext,
3462) -> mpsc::UnboundedReceiver<(
3463 context_server::types::CallToolParams,
3464 oneshot::Sender<context_server::types::CallToolResponse>,
3465)> {
3466 cx.update(|cx| {
3467 let mut settings = ProjectSettings::get_global(cx).clone();
3468 settings.context_servers.insert(
3469 name.into(),
3470 project::project_settings::ContextServerSettings::Stdio {
3471 enabled: true,
3472 command: ContextServerCommand {
3473 path: "somebinary".into(),
3474 args: Vec::new(),
3475 env: None,
3476 timeout: None,
3477 },
3478 },
3479 );
3480 ProjectSettings::override_global(settings, cx);
3481 });
3482
3483 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3484 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3485 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3486 context_server::types::InitializeResponse {
3487 protocol_version: context_server::types::ProtocolVersion(
3488 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3489 ),
3490 server_info: context_server::types::Implementation {
3491 name: name.into(),
3492 version: "1.0.0".to_string(),
3493 },
3494 capabilities: context_server::types::ServerCapabilities {
3495 tools: Some(context_server::types::ToolsCapabilities {
3496 list_changed: Some(true),
3497 }),
3498 ..Default::default()
3499 },
3500 meta: None,
3501 }
3502 })
3503 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3504 let tools = tools.clone();
3505 async move {
3506 context_server::types::ListToolsResponse {
3507 tools,
3508 next_cursor: None,
3509 meta: None,
3510 }
3511 }
3512 })
3513 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3514 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3515 async move {
3516 let (response_tx, response_rx) = oneshot::channel();
3517 mcp_tool_calls_tx
3518 .unbounded_send((params, response_tx))
3519 .unwrap();
3520 response_rx.await.unwrap()
3521 }
3522 });
3523 context_server_store.update(cx, |store, cx| {
3524 store.start_server(
3525 Arc::new(ContextServer::new(
3526 ContextServerId(name.into()),
3527 Arc::new(fake_transport),
3528 )),
3529 cx,
3530 );
3531 });
3532 cx.run_until_parked();
3533 mcp_tool_calls_rx
3534}
3535
3536#[gpui::test]
3537async fn test_tokens_before_message(cx: &mut TestAppContext) {
3538 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3539 let fake_model = model.as_fake();
3540
3541 // First message
3542 let message_1_id = UserMessageId::new();
3543 thread
3544 .update(cx, |thread, cx| {
3545 thread.send(message_1_id.clone(), ["First message"], cx)
3546 })
3547 .unwrap();
3548 cx.run_until_parked();
3549
3550 // Before any response, tokens_before_message should return None for first message
3551 thread.read_with(cx, |thread, _| {
3552 assert_eq!(
3553 thread.tokens_before_message(&message_1_id),
3554 None,
3555 "First message should have no tokens before it"
3556 );
3557 });
3558
3559 // Complete first message with usage
3560 fake_model.send_last_completion_stream_text_chunk("Response 1");
3561 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3562 language_model::TokenUsage {
3563 input_tokens: 100,
3564 output_tokens: 50,
3565 cache_creation_input_tokens: 0,
3566 cache_read_input_tokens: 0,
3567 },
3568 ));
3569 fake_model.end_last_completion_stream();
3570 cx.run_until_parked();
3571
3572 // First message still has no tokens before it
3573 thread.read_with(cx, |thread, _| {
3574 assert_eq!(
3575 thread.tokens_before_message(&message_1_id),
3576 None,
3577 "First message should still have no tokens before it after response"
3578 );
3579 });
3580
3581 // Second message
3582 let message_2_id = UserMessageId::new();
3583 thread
3584 .update(cx, |thread, cx| {
3585 thread.send(message_2_id.clone(), ["Second message"], cx)
3586 })
3587 .unwrap();
3588 cx.run_until_parked();
3589
3590 // Second message should have first message's input tokens before it
3591 thread.read_with(cx, |thread, _| {
3592 assert_eq!(
3593 thread.tokens_before_message(&message_2_id),
3594 Some(100),
3595 "Second message should have 100 tokens before it (from first request)"
3596 );
3597 });
3598
3599 // Complete second message
3600 fake_model.send_last_completion_stream_text_chunk("Response 2");
3601 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3602 language_model::TokenUsage {
3603 input_tokens: 250, // Total for this request (includes previous context)
3604 output_tokens: 75,
3605 cache_creation_input_tokens: 0,
3606 cache_read_input_tokens: 0,
3607 },
3608 ));
3609 fake_model.end_last_completion_stream();
3610 cx.run_until_parked();
3611
3612 // Third message
3613 let message_3_id = UserMessageId::new();
3614 thread
3615 .update(cx, |thread, cx| {
3616 thread.send(message_3_id.clone(), ["Third message"], cx)
3617 })
3618 .unwrap();
3619 cx.run_until_parked();
3620
3621 // Third message should have second message's input tokens (250) before it
3622 thread.read_with(cx, |thread, _| {
3623 assert_eq!(
3624 thread.tokens_before_message(&message_3_id),
3625 Some(250),
3626 "Third message should have 250 tokens before it (from second request)"
3627 );
3628 // Second message should still have 100
3629 assert_eq!(
3630 thread.tokens_before_message(&message_2_id),
3631 Some(100),
3632 "Second message should still have 100 tokens before it"
3633 );
3634 // First message still has none
3635 assert_eq!(
3636 thread.tokens_before_message(&message_1_id),
3637 None,
3638 "First message should still have no tokens before it"
3639 );
3640 });
3641}
3642
3643#[gpui::test]
3644async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3645 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3646 let fake_model = model.as_fake();
3647
3648 // Set up three messages with responses
3649 let message_1_id = UserMessageId::new();
3650 thread
3651 .update(cx, |thread, cx| {
3652 thread.send(message_1_id.clone(), ["Message 1"], cx)
3653 })
3654 .unwrap();
3655 cx.run_until_parked();
3656 fake_model.send_last_completion_stream_text_chunk("Response 1");
3657 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3658 language_model::TokenUsage {
3659 input_tokens: 100,
3660 output_tokens: 50,
3661 cache_creation_input_tokens: 0,
3662 cache_read_input_tokens: 0,
3663 },
3664 ));
3665 fake_model.end_last_completion_stream();
3666 cx.run_until_parked();
3667
3668 let message_2_id = UserMessageId::new();
3669 thread
3670 .update(cx, |thread, cx| {
3671 thread.send(message_2_id.clone(), ["Message 2"], cx)
3672 })
3673 .unwrap();
3674 cx.run_until_parked();
3675 fake_model.send_last_completion_stream_text_chunk("Response 2");
3676 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3677 language_model::TokenUsage {
3678 input_tokens: 250,
3679 output_tokens: 75,
3680 cache_creation_input_tokens: 0,
3681 cache_read_input_tokens: 0,
3682 },
3683 ));
3684 fake_model.end_last_completion_stream();
3685 cx.run_until_parked();
3686
3687 // Verify initial state
3688 thread.read_with(cx, |thread, _| {
3689 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3690 });
3691
3692 // Truncate at message 2 (removes message 2 and everything after)
3693 thread
3694 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3695 .unwrap();
3696 cx.run_until_parked();
3697
3698 // After truncation, message_2_id no longer exists, so lookup should return None
3699 thread.read_with(cx, |thread, _| {
3700 assert_eq!(
3701 thread.tokens_before_message(&message_2_id),
3702 None,
3703 "After truncation, message 2 no longer exists"
3704 );
3705 // Message 1 still exists but has no tokens before it
3706 assert_eq!(
3707 thread.tokens_before_message(&message_1_id),
3708 None,
3709 "First message still has no tokens before it"
3710 );
3711 });
3712}
3713
/// Exercises the `terminal` tool against four permission configurations.
///
/// Each scenario clones the current global `AgentSettings`, replaces the
/// `"terminal"` entry in `tool_permissions.tools`, and writes the result back,
/// so fields a scenario does not touch keep whatever an earlier scenario left
/// there. The scenarios therefore have to run in this order.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    // The command matches the `rm\s+-rf` deny pattern; the run must fail with
    // an error containing "blocked".
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        assert!(
            result.unwrap_err().to_string().contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    // The command matches the `^echo\s` allow pattern; the test expects an
    // update carrying terminal content and then a successful result.
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
    // The task is never awaited here: only the authorization request is
    // checked, and nothing in this scenario answers it.
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let _task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let auth = rx.expect_authorization().await;
        assert!(
            auth.tool_call.fields.title.is_some(),
            "expected authorization request for sudo command despite always_allow_tool_actions=true"
        );
    }

    // Test 4: default_mode: Deny blocks commands when no pattern matches
    // `always_allow_tool_actions` is true and every pattern list is empty; the
    // run must still fail, with an error containing "disabled".
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by default_mode: Deny"
        );
        assert!(
            result.unwrap_err().to_string().contains("disabled"),
            "error should mention the tool is disabled"
        );
    }
}
3929
3930#[gpui::test]
3931async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
3932 init_test(cx);
3933
3934 let fs = FakeFs::new(cx.executor());
3935 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
3936 .await;
3937 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
3938
3939 cx.update(|cx| {
3940 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3941 settings.tool_permissions.tools.insert(
3942 "edit_file".into(),
3943 agent_settings::ToolRules {
3944 default_mode: settings::ToolPermissionMode::Allow,
3945 always_allow: vec![],
3946 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
3947 always_confirm: vec![],
3948 invalid_patterns: vec![],
3949 },
3950 );
3951 agent_settings::AgentSettings::override_global(settings, cx);
3952 });
3953
3954 let context_server_registry =
3955 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
3956 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
3957 let templates = crate::Templates::new();
3958 let thread = cx.new(|cx| {
3959 crate::Thread::new(
3960 project.clone(),
3961 cx.new(|_cx| prompt_store::ProjectContext::default()),
3962 context_server_registry,
3963 templates.clone(),
3964 None,
3965 cx,
3966 )
3967 });
3968
3969 #[allow(clippy::arc_with_non_send_sync)]
3970 let tool = Arc::new(crate::EditFileTool::new(
3971 project.clone(),
3972 thread.downgrade(),
3973 language_registry,
3974 templates,
3975 ));
3976 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3977
3978 let task = cx.update(|cx| {
3979 tool.run(
3980 crate::EditFileToolInput {
3981 display_description: "Edit sensitive file".to_string(),
3982 path: "root/sensitive_config.txt".into(),
3983 mode: crate::EditFileMode::Edit,
3984 },
3985 event_stream,
3986 cx,
3987 )
3988 });
3989
3990 let result = task.await;
3991 assert!(result.is_err(), "expected edit to be blocked");
3992 assert!(
3993 result.unwrap_err().to_string().contains("blocked"),
3994 "error should mention the edit was blocked"
3995 );
3996}
3997
3998#[gpui::test]
3999async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4000 init_test(cx);
4001
4002 let fs = FakeFs::new(cx.executor());
4003 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4004 .await;
4005 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4006
4007 cx.update(|cx| {
4008 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4009 settings.tool_permissions.tools.insert(
4010 "delete_path".into(),
4011 agent_settings::ToolRules {
4012 default_mode: settings::ToolPermissionMode::Allow,
4013 always_allow: vec![],
4014 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4015 always_confirm: vec![],
4016 invalid_patterns: vec![],
4017 },
4018 );
4019 agent_settings::AgentSettings::override_global(settings, cx);
4020 });
4021
4022 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4023
4024 #[allow(clippy::arc_with_non_send_sync)]
4025 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4026 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4027
4028 let task = cx.update(|cx| {
4029 tool.run(
4030 crate::DeletePathToolInput {
4031 path: "root/important_data.txt".to_string(),
4032 },
4033 event_stream,
4034 cx,
4035 )
4036 });
4037
4038 let result = task.await;
4039 assert!(result.is_err(), "expected deletion to be blocked");
4040 assert!(
4041 result.unwrap_err().to_string().contains("blocked"),
4042 "error should mention the deletion was blocked"
4043 );
4044}
4045
4046#[gpui::test]
4047async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4048 init_test(cx);
4049
4050 let fs = FakeFs::new(cx.executor());
4051 fs.insert_tree(
4052 "/root",
4053 json!({
4054 "safe.txt": "content",
4055 "protected": {}
4056 }),
4057 )
4058 .await;
4059 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4060
4061 cx.update(|cx| {
4062 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4063 settings.tool_permissions.tools.insert(
4064 "move_path".into(),
4065 agent_settings::ToolRules {
4066 default_mode: settings::ToolPermissionMode::Allow,
4067 always_allow: vec![],
4068 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4069 always_confirm: vec![],
4070 invalid_patterns: vec![],
4071 },
4072 );
4073 agent_settings::AgentSettings::override_global(settings, cx);
4074 });
4075
4076 #[allow(clippy::arc_with_non_send_sync)]
4077 let tool = Arc::new(crate::MovePathTool::new(project));
4078 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4079
4080 let task = cx.update(|cx| {
4081 tool.run(
4082 crate::MovePathToolInput {
4083 source_path: "root/safe.txt".to_string(),
4084 destination_path: "root/protected/safe.txt".to_string(),
4085 },
4086 event_stream,
4087 cx,
4088 )
4089 });
4090
4091 let result = task.await;
4092 assert!(
4093 result.is_err(),
4094 "expected move to be blocked due to destination path"
4095 );
4096 assert!(
4097 result.unwrap_err().to_string().contains("blocked"),
4098 "error should mention the move was blocked"
4099 );
4100}
4101
4102#[gpui::test]
4103async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4104 init_test(cx);
4105
4106 let fs = FakeFs::new(cx.executor());
4107 fs.insert_tree(
4108 "/root",
4109 json!({
4110 "secret.txt": "secret content",
4111 "public": {}
4112 }),
4113 )
4114 .await;
4115 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4116
4117 cx.update(|cx| {
4118 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4119 settings.tool_permissions.tools.insert(
4120 "move_path".into(),
4121 agent_settings::ToolRules {
4122 default_mode: settings::ToolPermissionMode::Allow,
4123 always_allow: vec![],
4124 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
4125 always_confirm: vec![],
4126 invalid_patterns: vec![],
4127 },
4128 );
4129 agent_settings::AgentSettings::override_global(settings, cx);
4130 });
4131
4132 #[allow(clippy::arc_with_non_send_sync)]
4133 let tool = Arc::new(crate::MovePathTool::new(project));
4134 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4135
4136 let task = cx.update(|cx| {
4137 tool.run(
4138 crate::MovePathToolInput {
4139 source_path: "root/secret.txt".to_string(),
4140 destination_path: "root/public/not_secret.txt".to_string(),
4141 },
4142 event_stream,
4143 cx,
4144 )
4145 });
4146
4147 let result = task.await;
4148 assert!(
4149 result.is_err(),
4150 "expected move to be blocked due to source path"
4151 );
4152 assert!(
4153 result.unwrap_err().to_string().contains("blocked"),
4154 "error should mention the move was blocked"
4155 );
4156}
4157
4158#[gpui::test]
4159async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
4160 init_test(cx);
4161
4162 let fs = FakeFs::new(cx.executor());
4163 fs.insert_tree(
4164 "/root",
4165 json!({
4166 "confidential.txt": "confidential data",
4167 "dest": {}
4168 }),
4169 )
4170 .await;
4171 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4172
4173 cx.update(|cx| {
4174 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4175 settings.tool_permissions.tools.insert(
4176 "copy_path".into(),
4177 agent_settings::ToolRules {
4178 default_mode: settings::ToolPermissionMode::Allow,
4179 always_allow: vec![],
4180 always_deny: vec![
4181 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
4182 ],
4183 always_confirm: vec![],
4184 invalid_patterns: vec![],
4185 },
4186 );
4187 agent_settings::AgentSettings::override_global(settings, cx);
4188 });
4189
4190 #[allow(clippy::arc_with_non_send_sync)]
4191 let tool = Arc::new(crate::CopyPathTool::new(project));
4192 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4193
4194 let task = cx.update(|cx| {
4195 tool.run(
4196 crate::CopyPathToolInput {
4197 source_path: "root/confidential.txt".to_string(),
4198 destination_path: "root/dest/copy.txt".to_string(),
4199 },
4200 event_stream,
4201 cx,
4202 )
4203 });
4204
4205 let result = task.await;
4206 assert!(result.is_err(), "expected copy to be blocked");
4207 assert!(
4208 result.unwrap_err().to_string().contains("blocked"),
4209 "error should mention the copy was blocked"
4210 );
4211}
4212
4213#[gpui::test]
4214async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
4215 init_test(cx);
4216
4217 let fs = FakeFs::new(cx.executor());
4218 fs.insert_tree(
4219 "/root",
4220 json!({
4221 "normal.txt": "normal content",
4222 "readonly": {
4223 "config.txt": "readonly content"
4224 }
4225 }),
4226 )
4227 .await;
4228 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4229
4230 cx.update(|cx| {
4231 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4232 settings.tool_permissions.tools.insert(
4233 "save_file".into(),
4234 agent_settings::ToolRules {
4235 default_mode: settings::ToolPermissionMode::Allow,
4236 always_allow: vec![],
4237 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
4238 always_confirm: vec![],
4239 invalid_patterns: vec![],
4240 },
4241 );
4242 agent_settings::AgentSettings::override_global(settings, cx);
4243 });
4244
4245 #[allow(clippy::arc_with_non_send_sync)]
4246 let tool = Arc::new(crate::SaveFileTool::new(project));
4247 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4248
4249 let task = cx.update(|cx| {
4250 tool.run(
4251 crate::SaveFileToolInput {
4252 paths: vec![
4253 std::path::PathBuf::from("root/normal.txt"),
4254 std::path::PathBuf::from("root/readonly/config.txt"),
4255 ],
4256 },
4257 event_stream,
4258 cx,
4259 )
4260 });
4261
4262 let result = task.await;
4263 assert!(
4264 result.is_err(),
4265 "expected save to be blocked due to denied path"
4266 );
4267 assert!(
4268 result.unwrap_err().to_string().contains("blocked"),
4269 "error should mention the save was blocked"
4270 );
4271}
4272
4273#[gpui::test]
4274async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
4275 init_test(cx);
4276
4277 let fs = FakeFs::new(cx.executor());
4278 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
4279 .await;
4280 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4281
4282 cx.update(|cx| {
4283 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4284 settings.always_allow_tool_actions = false;
4285 settings.tool_permissions.tools.insert(
4286 "save_file".into(),
4287 agent_settings::ToolRules {
4288 default_mode: settings::ToolPermissionMode::Allow,
4289 always_allow: vec![],
4290 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
4291 always_confirm: vec![],
4292 invalid_patterns: vec![],
4293 },
4294 );
4295 agent_settings::AgentSettings::override_global(settings, cx);
4296 });
4297
4298 #[allow(clippy::arc_with_non_send_sync)]
4299 let tool = Arc::new(crate::SaveFileTool::new(project));
4300 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4301
4302 let task = cx.update(|cx| {
4303 tool.run(
4304 crate::SaveFileToolInput {
4305 paths: vec![std::path::PathBuf::from("root/config.secret")],
4306 },
4307 event_stream,
4308 cx,
4309 )
4310 });
4311
4312 let result = task.await;
4313 assert!(result.is_err(), "expected save to be blocked");
4314 assert!(
4315 result.unwrap_err().to_string().contains("blocked"),
4316 "error should mention the save was blocked"
4317 );
4318}
4319
4320#[gpui::test]
4321async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
4322 init_test(cx);
4323
4324 cx.update(|cx| {
4325 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4326 settings.tool_permissions.tools.insert(
4327 "web_search".into(),
4328 agent_settings::ToolRules {
4329 default_mode: settings::ToolPermissionMode::Allow,
4330 always_allow: vec![],
4331 always_deny: vec![
4332 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
4333 ],
4334 always_confirm: vec![],
4335 invalid_patterns: vec![],
4336 },
4337 );
4338 agent_settings::AgentSettings::override_global(settings, cx);
4339 });
4340
4341 #[allow(clippy::arc_with_non_send_sync)]
4342 let tool = Arc::new(crate::WebSearchTool);
4343 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4344
4345 let input: crate::WebSearchToolInput =
4346 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
4347
4348 let task = cx.update(|cx| tool.run(input, event_stream, cx));
4349
4350 let result = task.await;
4351 assert!(result.is_err(), "expected search to be blocked");
4352 assert!(
4353 result.unwrap_err().to_string().contains("blocked"),
4354 "error should mention the search was blocked"
4355 );
4356}
4357
4358#[gpui::test]
4359async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
4360 init_test(cx);
4361
4362 let fs = FakeFs::new(cx.executor());
4363 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
4364 .await;
4365 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4366
4367 cx.update(|cx| {
4368 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4369 settings.always_allow_tool_actions = false;
4370 settings.tool_permissions.tools.insert(
4371 "edit_file".into(),
4372 agent_settings::ToolRules {
4373 default_mode: settings::ToolPermissionMode::Confirm,
4374 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
4375 always_deny: vec![],
4376 always_confirm: vec![],
4377 invalid_patterns: vec![],
4378 },
4379 );
4380 agent_settings::AgentSettings::override_global(settings, cx);
4381 });
4382
4383 let context_server_registry =
4384 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4385 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4386 let templates = crate::Templates::new();
4387 let thread = cx.new(|cx| {
4388 crate::Thread::new(
4389 project.clone(),
4390 cx.new(|_cx| prompt_store::ProjectContext::default()),
4391 context_server_registry,
4392 templates.clone(),
4393 None,
4394 cx,
4395 )
4396 });
4397
4398 #[allow(clippy::arc_with_non_send_sync)]
4399 let tool = Arc::new(crate::EditFileTool::new(
4400 project,
4401 thread.downgrade(),
4402 language_registry,
4403 templates,
4404 ));
4405 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4406
4407 let _task = cx.update(|cx| {
4408 tool.run(
4409 crate::EditFileToolInput {
4410 display_description: "Edit README".to_string(),
4411 path: "root/README.md".into(),
4412 mode: crate::EditFileMode::Edit,
4413 },
4414 event_stream,
4415 cx,
4416 )
4417 });
4418
4419 cx.run_until_parked();
4420
4421 let event = rx.try_next();
4422 assert!(
4423 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
4424 "expected no authorization request for allowed .md file"
4425 );
4426}
4427
4428#[gpui::test]
4429async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
4430 init_test(cx);
4431
4432 cx.update(|cx| {
4433 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4434 settings.tool_permissions.tools.insert(
4435 "fetch".into(),
4436 agent_settings::ToolRules {
4437 default_mode: settings::ToolPermissionMode::Allow,
4438 always_allow: vec![],
4439 always_deny: vec![
4440 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
4441 ],
4442 always_confirm: vec![],
4443 invalid_patterns: vec![],
4444 },
4445 );
4446 agent_settings::AgentSettings::override_global(settings, cx);
4447 });
4448
4449 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
4450
4451 #[allow(clippy::arc_with_non_send_sync)]
4452 let tool = Arc::new(crate::FetchTool::new(http_client));
4453 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4454
4455 let input: crate::FetchToolInput =
4456 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
4457
4458 let task = cx.update(|cx| tool.run(input, event_stream, cx));
4459
4460 let result = task.await;
4461 assert!(result.is_err(), "expected fetch to be blocked");
4462 assert!(
4463 result.unwrap_err().to_string().contains("blocked"),
4464 "error should mention the fetch was blocked"
4465 );
4466}
4467
4468#[gpui::test]
4469async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
4470 init_test(cx);
4471
4472 cx.update(|cx| {
4473 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4474 settings.always_allow_tool_actions = false;
4475 settings.tool_permissions.tools.insert(
4476 "fetch".into(),
4477 agent_settings::ToolRules {
4478 default_mode: settings::ToolPermissionMode::Confirm,
4479 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
4480 always_deny: vec![],
4481 always_confirm: vec![],
4482 invalid_patterns: vec![],
4483 },
4484 );
4485 agent_settings::AgentSettings::override_global(settings, cx);
4486 });
4487
4488 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
4489
4490 #[allow(clippy::arc_with_non_send_sync)]
4491 let tool = Arc::new(crate::FetchTool::new(http_client));
4492 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4493
4494 let input: crate::FetchToolInput =
4495 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
4496
4497 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
4498
4499 cx.run_until_parked();
4500
4501 let event = rx.try_next();
4502 assert!(
4503 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
4504 "expected no authorization request for allowed docs.rs URL"
4505 );
4506}