1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 FutureExt as _, StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17 future::{Fuse, Shared},
18};
19use gpui::{
20 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
21 http_client::FakeHttpClient,
22};
23use indoc::indoc;
24use language_model::{
25 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
26 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
28 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
29};
30use pretty_assertions::assert_eq;
31use project::{
32 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
33};
34use prompt_store::ProjectContext;
35use reqwest_client::ReqwestClient;
36use schemars::JsonSchema;
37use serde::{Deserialize, Serialize};
38use serde_json::json;
39use settings::{Settings, SettingsStore};
40use std::{
41 path::Path,
42 pin::Pin,
43 rc::Rc,
44 sync::{
45 Arc,
46 atomic::{AtomicBool, Ordering},
47 },
48 time::Duration,
49};
50use util::path;
51
52mod test_tools;
53use test_tools::*;
54
55fn init_test(cx: &mut TestAppContext) {
56 cx.update(|cx| {
57 let settings_store = SettingsStore::test(cx);
58 cx.set_global(settings_store);
59 });
60}
61
/// Test double for a terminal spawned by the terminal tool.
///
/// Tracks whether `kill` was invoked and exposes a `wait_for_exit` task that
/// only resolves once the handle has been killed.
struct FakeTerminalHandle {
    // Set to true by `kill`; read by the exit-polling task and `was_killed`.
    killed: Arc<AtomicBool>,
    // Shared task that resolves once `killed` flips to true.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from `current_output`.
    output: acp::TerminalOutputResponse,
    // Stable id returned from `id`.
    id: acp::TerminalId,
}

impl FakeTerminalHandle {
    /// Creates a handle whose `wait_for_exit` never resolves on its own: it
    /// polls the `killed` flag every millisecond and completes only after
    /// `kill` is called, simulating a command that runs forever.
    fn new_never_exits(cx: &mut App) -> Self {
        let killed = Arc::new(AtomicBool::new(false));

        let killed_for_task = killed.clone();
        let wait_for_exit = cx
            .spawn(async move |cx| {
                loop {
                    if killed_for_task.load(Ordering::SeqCst) {
                        return acp::TerminalExitStatus::new();
                    }
                    // Yield briefly so the test executor can make progress
                    // between polls of the kill flag.
                    cx.background_executor()
                        .timer(Duration::from_millis(1))
                        .await;
                }
            })
            .shared();

        Self {
            killed,
            wait_for_exit,
            output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
            id: acp::TerminalId::new("fake_terminal".to_string()),
        }
    }

    /// Returns whether `kill` has been called on this handle.
    fn was_killed(&self) -> bool {
        self.killed.load(Ordering::SeqCst)
    }
}
99
impl crate::TerminalHandle for FakeTerminalHandle {
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    // Always reports the same canned "partial output" snapshot.
    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    // Flips the shared `killed` flag, which lets `wait_for_exit` resolve.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        Ok(())
    }
}
118
/// Test environment that hands out the same [`FakeTerminalHandle`] for every
/// terminal the tool under test creates.
struct FakeThreadEnvironment {
    handle: Rc<FakeTerminalHandle>,
}

impl crate::ThreadEnvironment for FakeThreadEnvironment {
    // Ignores the requested command/cwd/output limit and returns the shared
    // fake handle immediately.
    fn create_terminal(
        &self,
        _command: String,
        _cwd: Option<std::path::PathBuf>,
        _output_byte_limit: Option<u64>,
        _cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
        Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
    }
}
134
135fn always_allow_tools(cx: &mut TestAppContext) {
136 cx.update(|cx| {
137 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
138 settings.always_allow_tool_actions = true;
139 agent_settings::AgentSettings::override_global(settings, cx);
140 });
141}
142
143#[gpui::test]
144async fn test_echo(cx: &mut TestAppContext) {
145 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
146 let fake_model = model.as_fake();
147
148 let events = thread
149 .update(cx, |thread, cx| {
150 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
151 })
152 .unwrap();
153 cx.run_until_parked();
154 fake_model.send_last_completion_stream_text_chunk("Hello");
155 fake_model
156 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
157 fake_model.end_last_completion_stream();
158
159 let events = events.collect().await;
160 thread.update(cx, |thread, _cx| {
161 assert_eq!(
162 thread.last_message().unwrap().to_markdown(),
163 indoc! {"
164 ## Assistant
165
166 Hello
167 "}
168 )
169 });
170 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
171}
172
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own; it only finishes once killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Run with a very short timeout so the tool is forced to kill the
    // never-exiting terminal.
    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The tool should first surface the terminal in a tool-call update.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked plus tiny timer ticks)
    // until it completes, bounded by a wall-clock deadline so a regression
    // fails fast instead of hanging the test forever.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            // On timeout the tool should still report whatever output the
            // terminal produced before being killed.
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
238
#[gpui::test]
#[ignore] // NOTE(review): disabled; relies on a real 25ms wall-clock sleep — confirm why before re-enabling.
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own; it only finishes once killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // No timeout: the tool must never kill the terminal of its own accord.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool some real time to (incorrectly) fire a kill, then verify
    // the handle is still alive.
    smol::Timer::after(Duration::from_millis(25)).await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
286
287#[gpui::test]
288async fn test_thinking(cx: &mut TestAppContext) {
289 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
290 let fake_model = model.as_fake();
291
292 let events = thread
293 .update(cx, |thread, cx| {
294 thread.send(
295 UserMessageId::new(),
296 [indoc! {"
297 Testing:
298
299 Generate a thinking step where you just think the word 'Think',
300 and have your final answer be 'Hello'
301 "}],
302 cx,
303 )
304 })
305 .unwrap();
306 cx.run_until_parked();
307 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
308 text: "Think".to_string(),
309 signature: None,
310 });
311 fake_model.send_last_completion_stream_text_chunk("Hello");
312 fake_model
313 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
314 fake_model.end_last_completion_stream();
315
316 let events = events.collect().await;
317 thread.update(cx, |thread, _cx| {
318 assert_eq!(
319 thread.last_message().unwrap().to_markdown(),
320 indoc! {"
321 ## Assistant
322
323 <think>Think</think>
324 Hello
325 "}
326 )
327 });
328 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
329}
330
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Give the project context a recognizable shell name so we can assert it
    // shows up in the rendered system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // Register a tool so the tool-related prompt sections are included.
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt must be the first message of the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // With tools registered, the diagnostics section should be present.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
375
376#[gpui::test]
377async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
378 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
379 let fake_model = model.as_fake();
380
381 thread
382 .update(cx, |thread, cx| {
383 thread.send(UserMessageId::new(), ["abc"], cx)
384 })
385 .unwrap();
386 cx.run_until_parked();
387 let mut pending_completions = fake_model.pending_completions();
388 assert_eq!(
389 pending_completions.len(),
390 1,
391 "unexpected pending completions: {:?}",
392 pending_completions
393 );
394
395 let pending_completion = pending_completions.pop().unwrap();
396 assert_eq!(pending_completion.messages[0].role, Role::System);
397
398 let system_message = &pending_completion.messages[0];
399 let system_prompt = system_message.content[0].to_str().unwrap();
400 assert!(
401 !system_prompt.contains("## Tool Use"),
402 "unexpected system message: {:?}",
403 system_message
404 );
405 assert!(
406 !system_prompt.contains("## Fixing Diagnostics"),
407 "unexpected system message: {:?}",
408 system_message
409 );
410}
411
// Verifies the prompt-cache policy: only the most recent user message (or
// tool result) in a request is marked `cache: true`; earlier messages are
// sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // Only the trailing tool-result message carries `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
557
// End-to-end test against a real model (only runs with the "e2e" feature).
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should be echoed back somewhere in
        // the agent's final message text.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
617
// End-to-end test (only runs with the "e2e" feature): verifies that tool-call
// input streams incrementally, i.e. we observe at least one partial input
// before the input is complete.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // incomplete and missing its final key ("g").
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
669
670#[gpui::test]
671async fn test_tool_authorization(cx: &mut TestAppContext) {
672 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
673 let fake_model = model.as_fake();
674
675 let mut events = thread
676 .update(cx, |thread, cx| {
677 thread.add_tool(ToolRequiringPermission);
678 thread.send(UserMessageId::new(), ["abc"], cx)
679 })
680 .unwrap();
681 cx.run_until_parked();
682 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
683 LanguageModelToolUse {
684 id: "tool_id_1".into(),
685 name: ToolRequiringPermission::name().into(),
686 raw_input: "{}".into(),
687 input: json!({}),
688 is_input_complete: true,
689 thought_signature: None,
690 },
691 ));
692 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
693 LanguageModelToolUse {
694 id: "tool_id_2".into(),
695 name: ToolRequiringPermission::name().into(),
696 raw_input: "{}".into(),
697 input: json!({}),
698 is_input_complete: true,
699 thought_signature: None,
700 },
701 ));
702 fake_model.end_last_completion_stream();
703 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
704 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
705
706 // Approve the first
707 tool_call_auth_1
708 .response
709 .send(tool_call_auth_1.options[1].option_id.clone())
710 .unwrap();
711 cx.run_until_parked();
712
713 // Reject the second
714 tool_call_auth_2
715 .response
716 .send(tool_call_auth_1.options[2].option_id.clone())
717 .unwrap();
718 cx.run_until_parked();
719
720 let completion = fake_model.pending_completions().pop().unwrap();
721 let message = completion.messages.last().unwrap();
722 assert_eq!(
723 message.content,
724 vec![
725 language_model::MessageContent::ToolResult(LanguageModelToolResult {
726 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
727 tool_name: ToolRequiringPermission::name().into(),
728 is_error: false,
729 content: "Allowed".into(),
730 output: Some("Allowed".into())
731 }),
732 language_model::MessageContent::ToolResult(LanguageModelToolResult {
733 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
734 tool_name: ToolRequiringPermission::name().into(),
735 is_error: true,
736 content: "Permission to run tool denied by user".into(),
737 output: Some("Permission to run tool denied by user".into())
738 })
739 ]
740 );
741
742 // Simulate yet another tool call.
743 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
744 LanguageModelToolUse {
745 id: "tool_id_3".into(),
746 name: ToolRequiringPermission::name().into(),
747 raw_input: "{}".into(),
748 input: json!({}),
749 is_input_complete: true,
750 thought_signature: None,
751 },
752 ));
753 fake_model.end_last_completion_stream();
754
755 // Respond by always allowing tools.
756 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
757 tool_call_auth_3
758 .response
759 .send(tool_call_auth_3.options[0].option_id.clone())
760 .unwrap();
761 cx.run_until_parked();
762 let completion = fake_model.pending_completions().pop().unwrap();
763 let message = completion.messages.last().unwrap();
764 assert_eq!(
765 message.content,
766 vec![language_model::MessageContent::ToolResult(
767 LanguageModelToolResult {
768 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
769 tool_name: ToolRequiringPermission::name().into(),
770 is_error: false,
771 content: "Allowed".into(),
772 output: Some("Allowed".into())
773 }
774 )]
775 );
776
777 // Simulate a final tool call, ensuring we don't trigger authorization.
778 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
779 LanguageModelToolUse {
780 id: "tool_id_4".into(),
781 name: ToolRequiringPermission::name().into(),
782 raw_input: "{}".into(),
783 input: json!({}),
784 is_input_complete: true,
785 thought_signature: None,
786 },
787 ));
788 fake_model.end_last_completion_stream();
789 cx.run_until_parked();
790 let completion = fake_model.pending_completions().pop().unwrap();
791 let message = completion.messages.last().unwrap();
792 assert_eq!(
793 message.content,
794 vec![language_model::MessageContent::ToolResult(
795 LanguageModelToolResult {
796 tool_use_id: "tool_id_4".into(),
797 tool_name: ToolRequiringPermission::name().into(),
798 is_error: false,
799 content: "Allowed".into(),
800 output: Some("Allowed".into())
801 }
802 )]
803 );
804}
805
806#[gpui::test]
807async fn test_tool_hallucination(cx: &mut TestAppContext) {
808 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
809 let fake_model = model.as_fake();
810
811 let mut events = thread
812 .update(cx, |thread, cx| {
813 thread.send(UserMessageId::new(), ["abc"], cx)
814 })
815 .unwrap();
816 cx.run_until_parked();
817 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
818 LanguageModelToolUse {
819 id: "tool_id_1".into(),
820 name: "nonexistent_tool".into(),
821 raw_input: "{}".into(),
822 input: json!({}),
823 is_input_complete: true,
824 thought_signature: None,
825 },
826 ));
827 fake_model.end_last_completion_stream();
828
829 let tool_call = expect_tool_call(&mut events).await;
830 assert_eq!(tool_call.title, "nonexistent_tool");
831 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
832 let update = expect_tool_call_update_fields(&mut events).await;
833 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
834}
835
// After hitting the tool-use limit, `resume` should continue the same turn by
// appending a "Continue where you left off" user message.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // messages[0] is the system prompt; the tool result is the latest message
    // and therefore the one marked `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming should re-send the history plus a synthetic continuation
    // message, which becomes the newly cached message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
950
// After hitting the tool-use limit, sending a *new* user message should work:
// the history (including the tool use/result pair) is preserved and the new
// message becomes the cached one.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The limit event arrives in the same stream as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new user message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1027
1028async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1029 let event = events
1030 .next()
1031 .await
1032 .expect("no tool call authorization event received")
1033 .unwrap();
1034 match event {
1035 ThreadEvent::ToolCall(tool_call) => tool_call,
1036 event => {
1037 panic!("Unexpected event {event:?}");
1038 }
1039 }
1040}
1041
1042async fn expect_tool_call_update_fields(
1043 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1044) -> acp::ToolCallUpdate {
1045 let event = events
1046 .next()
1047 .await
1048 .expect("no tool call authorization event received")
1049 .unwrap();
1050 match event {
1051 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1052 event => {
1053 panic!("Unexpected event {event:?}");
1054 }
1055 }
1056}
1057
1058async fn next_tool_call_authorization(
1059 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1060) -> ToolCallAuthorization {
1061 loop {
1062 let event = events
1063 .next()
1064 .await
1065 .expect("no tool call authorization event received")
1066 .unwrap();
1067 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1068 let permission_kinds = tool_call_authorization
1069 .options
1070 .iter()
1071 .map(|o| o.kind)
1072 .collect::<Vec<_>>();
1073 assert_eq!(
1074 permission_kinds,
1075 vec![
1076 acp::PermissionOptionKind::AllowAlways,
1077 acp::PermissionOptionKind::AllowOnce,
1078 acp::PermissionOptionKind::RejectOnce,
1079 ]
1080 );
1081 return tool_call_authorization;
1082 }
1083 }
1084}
1085
// End-to-end test (only runs with the "e2e" feature): two delay-tool calls in
// one message should both complete and be described in the final answer.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // that the delay tool's output ("Ding") was echoed back.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
1130
/// Verifies that switching agent profiles changes which tools are offered to
/// the model: profile "test-1" exposes echo + delay, profile "test-2" exposes
/// only the infinite tool.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the active profile decides
    // which subset is actually sent to the model.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the completion request the fake model received and extract the
    // tool names it was offered.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    // Close the stream so the first turn finishes before starting the next.
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1210
/// Verifies MCP (context server) tool integration: an MCP tool can be called
/// by the model, and when a native tool later shares its name ("echo"), the
/// MCP tool is re-exposed under a server-prefixed name ("test_server_echo")
/// and results are routed back under the correct names.
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Fake MCP server exposing a single "echo" tool; the receiver yields each
    // incoming tool call together with a oneshot-style responder.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No name collision yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should have received the call; answer it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): re-asserts on the completion captured above (not a fresh
    // one), so this checks the same request again.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    // Drain the first turn's event stream to completion.
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is now disambiguated with its server-name prefix.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Model calls both the prefixed MCP tool and the native echo tool.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the MCP call goes to the server; the native echo runs in-process.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1376
/// Verifies how tool names from multiple MCP servers are disambiguated and
/// truncated: conflicting names are prefixed with the server name, and names
/// at or over MAX_TOOL_NAME_LENGTH are shortened (apparently keeping a
/// single-character server prefix — TODO confirm against the resolution code).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names: adding a server prefix would exceed the limit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Third server duplicates the long names and adds one over the limit.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The final, deduplicated and truncated tool list offered to the model.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1542
/// End-to-end test (requires the `e2e` feature and a real model): cancelling
/// a turn while a tool is still running closes the event stream with a
/// Cancelled stop reason, and the thread can accept a new message afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Watch for the echo tool's completion update so we know at least
            // one tool finished before we cancel.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1628
1629#[gpui::test]
1630async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1631 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1632 let fake_model = model.as_fake();
1633
1634 let events_1 = thread
1635 .update(cx, |thread, cx| {
1636 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1637 })
1638 .unwrap();
1639 cx.run_until_parked();
1640 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1641 cx.run_until_parked();
1642
1643 let events_2 = thread
1644 .update(cx, |thread, cx| {
1645 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1646 })
1647 .unwrap();
1648 cx.run_until_parked();
1649 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1650 fake_model
1651 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1652 fake_model.end_last_completion_stream();
1653
1654 let events_1 = events_1.collect::<Vec<_>>().await;
1655 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1656 let events_2 = events_2.collect::<Vec<_>>().await;
1657 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1658}
1659
1660#[gpui::test]
1661async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1662 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1663 let fake_model = model.as_fake();
1664
1665 let events_1 = thread
1666 .update(cx, |thread, cx| {
1667 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1668 })
1669 .unwrap();
1670 cx.run_until_parked();
1671 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1672 fake_model
1673 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1674 fake_model.end_last_completion_stream();
1675 let events_1 = events_1.collect::<Vec<_>>().await;
1676
1677 let events_2 = thread
1678 .update(cx, |thread, cx| {
1679 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1680 })
1681 .unwrap();
1682 cx.run_until_parked();
1683 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1684 fake_model
1685 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1686 fake_model.end_last_completion_stream();
1687 let events_2 = events_2.collect::<Vec<_>>().await;
1688
1689 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1690 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1691}
1692
/// When the model stops with a Refusal, the thread drops everything back to
/// (and including) the last user message, leaving the transcript empty here.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Snapshot: only the user message so far.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Snapshot: partial assistant reply has been appended.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // With only one user message in the thread, the refusal empties it entirely.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1741
/// Truncating at the first (and only) user message empties the thread and
/// resets token usage, and the thread remains usable for a fresh send.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the streamed UsageUpdate (input + output tokens).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message clears the transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from before the truncation must not leak into the new turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1857
/// Truncating at the second user message restores the thread to exactly the
/// state it had after the first exchange, including the earlier token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: user message + streamed response + usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable snapshot check: asserted both before the second exchange and
    // again after truncating back to it.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate away the second exchange; state must match the first snapshot.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1966
/// A summarization model generates the thread title after the first turn
/// (using only the first line of its output), and is not invoked again for
/// subsequent turns.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model dedicated to title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title stays at the placeholder until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks deliberately split words and contain a newline: only the
    // first line ("Hello world") becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2012
/// When building a completion request while one tool call is still awaiting
/// permission, the pending tool use (and its result) is omitted from the
/// request, while the completed tool's use and result are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending: permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is presumably the system prompt — compare from index 1 on.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    // Only the completed echo tool use appears; the
                    // permission-gated tool use is omitted.
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2092
/// Integration test of the ACP `AgentConnection` surface: creating a thread,
/// listing/selecting models, prompting, cancelling, and verifying the native
/// thread is dropped (session lost) when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        // 404-everything HTTP client: no real network traffic in this test.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2223
2224#[gpui::test]
2225async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2226 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2227 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2228 let fake_model = model.as_fake();
2229
2230 let mut events = thread
2231 .update(cx, |thread, cx| {
2232 thread.send(UserMessageId::new(), ["Think"], cx)
2233 })
2234 .unwrap();
2235 cx.run_until_parked();
2236
2237 // Simulate streaming partial input.
2238 let input = json!({});
2239 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2240 LanguageModelToolUse {
2241 id: "1".into(),
2242 name: ThinkingTool::name().into(),
2243 raw_input: input.to_string(),
2244 input,
2245 is_input_complete: false,
2246 thought_signature: None,
2247 },
2248 ));
2249
2250 // Input streaming completed
2251 let input = json!({ "content": "Thinking hard!" });
2252 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2253 LanguageModelToolUse {
2254 id: "1".into(),
2255 name: "thinking".into(),
2256 raw_input: input.to_string(),
2257 input,
2258 is_input_complete: true,
2259 thought_signature: None,
2260 },
2261 ));
2262 fake_model.end_last_completion_stream();
2263 cx.run_until_parked();
2264
2265 let tool_call = expect_tool_call(&mut events).await;
2266 assert_eq!(
2267 tool_call,
2268 acp::ToolCall::new("1", "Thinking")
2269 .kind(acp::ToolKind::Think)
2270 .raw_input(json!({}))
2271 .meta(acp::Meta::from_iter([(
2272 "tool_name".into(),
2273 "thinking".into()
2274 )]))
2275 );
2276 let update = expect_tool_call_update_fields(&mut events).await;
2277 assert_eq!(
2278 update,
2279 acp::ToolCallUpdate::new(
2280 "1",
2281 acp::ToolCallUpdateFields::new()
2282 .title("Thinking")
2283 .kind(acp::ToolKind::Think)
2284 .raw_input(json!({ "content": "Thinking hard!"}))
2285 )
2286 );
2287 let update = expect_tool_call_update_fields(&mut events).await;
2288 assert_eq!(
2289 update,
2290 acp::ToolCallUpdate::new(
2291 "1",
2292 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2293 )
2294 );
2295 let update = expect_tool_call_update_fields(&mut events).await;
2296 assert_eq!(
2297 update,
2298 acp::ToolCallUpdate::new(
2299 "1",
2300 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2301 )
2302 );
2303 let update = expect_tool_call_update_fields(&mut events).await;
2304 assert_eq!(
2305 update,
2306 acp::ToolCallUpdate::new(
2307 "1",
2308 acp::ToolCallUpdateFields::new()
2309 .status(acp::ToolCallStatus::Completed)
2310 .raw_output("Finished thinking.")
2311 )
2312 );
2313}
2314
2315#[gpui::test]
2316async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2317 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2318 let fake_model = model.as_fake();
2319
2320 let mut events = thread
2321 .update(cx, |thread, cx| {
2322 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2323 thread.send(UserMessageId::new(), ["Hello!"], cx)
2324 })
2325 .unwrap();
2326 cx.run_until_parked();
2327
2328 fake_model.send_last_completion_stream_text_chunk("Hey!");
2329 fake_model.end_last_completion_stream();
2330
2331 let mut retry_events = Vec::new();
2332 while let Some(Ok(event)) = events.next().await {
2333 match event {
2334 ThreadEvent::Retry(retry_status) => {
2335 retry_events.push(retry_status);
2336 }
2337 ThreadEvent::Stop(..) => break,
2338 _ => {}
2339 }
2340 }
2341
2342 assert_eq!(retry_events.len(), 0);
2343 thread.read_with(cx, |thread, _cx| {
2344 assert_eq!(
2345 thread.to_markdown(),
2346 indoc! {"
2347 ## User
2348
2349 Hello!
2350
2351 ## Assistant
2352
2353 Hey!
2354 "}
2355 )
2356 });
2357}
2358
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a retryable provider error (server overloaded) should
    // pause the turn, wait out `retry_after`, and resume with a fresh
    // completion request — emitting exactly one Retry event.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial response, then fail with a retryable error that
    // asks the client to back off for 3 seconds.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt #1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the partial response and marks the resume point.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2423
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a completion fails after streaming a complete tool use, the tool
    // call should still be executed before retrying, and the retried
    // request must include both the tool use and its result.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a fully-formed tool use, then fail with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the backoff elapses, the retried request should carry the user
    // message, the assistant's tool use, and the (already-run) tool result.
    // (messages[0] is the system prompt, hence the [1..] slice.)
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2504
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, one more
    // failure should surface the error to the caller instead of retrying.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt (initial + MAX_RETRY_ATTEMPTS retries), advancing
    // the clock past each backoff so the next attempt is issued.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as the original provider error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2559
2560/// Filters out the stop events for asserting against in tests
2561fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2562 result_events
2563 .into_iter()
2564 .filter_map(|event| match event.unwrap() {
2565 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2566 _ => None,
2567 })
2568 .collect()
2569}
2570
/// Bundle of entities produced by `setup` that individual tests operate on.
struct ThreadTest {
    // Model the thread was created with (fake, or a real provider model).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context shared with the thread; tests may mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used to register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2578
/// Which language model a test runs against: the in-process fake model, or
/// a real remote model (requires network access and credentials).
enum TestModel {
    Sonnet4,
    Fake,
}

impl TestModel {
    /// Registry id used to look up a real model.
    /// `Fake` is constructed directly and never looked up by id, so
    /// reaching that arm is a bug in the test setup.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
2592
/// Builds the common test fixture: a fake filesystem with a settings file
/// enabling every test tool, a test project, the requested model, and a
/// fresh `Thread` wired to all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work on test threads.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file selecting a profile with all test tools
    // enabled, so prompts sent in tests may invoke any of them.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real model: needs an HTTP client, a production client
                // connection, and the language-model providers registered.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later edits tests
        // make to the settings file on the fake fs.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Real model: resolve it from the registry and authenticate
                // its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2694
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only install env_logger when the caller opted in via RUST_LOG, so
    // normal test runs stay quiet.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2702
2703fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2704 let fs = fs.clone();
2705 cx.spawn({
2706 async move |cx| {
2707 let mut new_settings_content_rx = settings::watch_config_file(
2708 cx.background_executor(),
2709 fs,
2710 paths::settings_file().clone(),
2711 );
2712
2713 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2714 cx.update(|cx| {
2715 SettingsStore::update_global(cx, |settings, cx| {
2716 settings.set_user_settings(&new_settings_content, cx)
2717 })
2718 })
2719 .ok();
2720 }
2721 }
2722 })
2723 .detach();
2724}
2725
2726fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2727 completion
2728 .tools
2729 .iter()
2730 .map(|tool| tool.name.clone())
2731 .collect()
2732}
2733
/// Registers a fake stdio MCP context server named `name` that exposes
/// `tools`, and starts it in `context_server_store`.
///
/// Returns a channel yielding each incoming tool call together with a
/// oneshot sender the test must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it.
    // The command is never executed — the fake transport below intercepts
    // the protocol — so "somebinary" is just a placeholder.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport speaking just enough MCP: handshake, tool listing,
    // and forwarding tool calls to the test through the channel above.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until
                // the test responds through the oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before tests interact with it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2812
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    // `tokens_before_message` should report, for each user message, the
    // input-token count of the *previous* request — None for the first
    // message — and earlier answers must stay stable as the thread grows.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
2919
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    // Truncating the thread at a message removes it (and everything after),
    // so `tokens_before_message` lookups for removed messages return None.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}