1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 FutureExt as _, StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17 future::{Fuse, Shared},
18};
19use gpui::{
20 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
21 http_client::FakeHttpClient,
22};
23use indoc::indoc;
24use language_model::{
25 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
26 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
28 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
29};
30use pretty_assertions::assert_eq;
31use project::{
32 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
33};
34use prompt_store::ProjectContext;
35use reqwest_client::ReqwestClient;
36use schemars::JsonSchema;
37use serde::{Deserialize, Serialize};
38use serde_json::json;
39use settings::{Settings, SettingsStore};
40use std::{
41 path::Path,
42 pin::Pin,
43 rc::Rc,
44 sync::{
45 Arc,
46 atomic::{AtomicBool, Ordering},
47 },
48 time::Duration,
49};
50use util::path;
51
52mod test_tools;
53use test_tools::*;
54
55fn init_test(cx: &mut TestAppContext) {
56 cx.update(|cx| {
57 let settings_store = SettingsStore::test(cx);
58 cx.set_global(settings_store);
59 });
60}
61
/// Test double for a terminal handle used by the terminal tool tests.
struct FakeTerminalHandle {
    // Set to true once `kill` is called; polled by the wait-for-exit task.
    killed: Arc<AtomicBool>,
    // Shared task that resolves only after `killed` becomes true.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from every `current_output` call.
    output: acp::TerminalOutputResponse,
    // Fixed terminal id returned from `id`.
    id: acp::TerminalId,
}
68
69impl FakeTerminalHandle {
70 fn new_never_exits(cx: &mut App) -> Self {
71 let killed = Arc::new(AtomicBool::new(false));
72
73 let killed_for_task = killed.clone();
74 let wait_for_exit = cx
75 .spawn(async move |cx| {
76 loop {
77 if killed_for_task.load(Ordering::SeqCst) {
78 return acp::TerminalExitStatus::new();
79 }
80 cx.background_executor()
81 .timer(Duration::from_millis(1))
82 .await;
83 }
84 })
85 .shared();
86
87 Self {
88 killed,
89 wait_for_exit,
90 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
91 id: acp::TerminalId::new("fake_terminal".to_string()),
92 }
93 }
94
95 fn was_killed(&self) -> bool {
96 self.killed.load(Ordering::SeqCst)
97 }
98}
99
100impl crate::TerminalHandle for FakeTerminalHandle {
101 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
102 Ok(self.id.clone())
103 }
104
105 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
106 Ok(self.output.clone())
107 }
108
109 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
110 Ok(self.wait_for_exit.clone())
111 }
112
113 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
114 self.killed.store(true, Ordering::SeqCst);
115 Ok(())
116 }
117
118 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
119 Ok(false)
120 }
121}
122
/// Thread-environment stub that vends the same fake terminal handle for
/// every `create_terminal` request.
struct FakeThreadEnvironment {
    // Handle returned by every `create_terminal` call.
    handle: Rc<FakeTerminalHandle>,
}
126
127impl crate::ThreadEnvironment for FakeThreadEnvironment {
128 fn create_terminal(
129 &self,
130 _command: String,
131 _cwd: Option<std::path::PathBuf>,
132 _output_byte_limit: Option<u64>,
133 _cx: &mut AsyncApp,
134 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
135 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
136 }
137}
138
139fn always_allow_tools(cx: &mut TestAppContext) {
140 cx.update(|cx| {
141 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
142 settings.always_allow_tool_actions = true;
143 agent_settings::AgentSettings::override_global(settings, cx);
144 });
145}
146
/// Smoke test: one user prompt, one streamed text chunk back, and the turn
/// ends with `EndTurn`; the reply is recorded verbatim on the thread.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Drive the fake model: stream "Hello", signal end of turn, close stream.
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
176
/// Verifies that when the terminal tool is given a `timeout_ms` and the
/// command outlives it, the tool kills the terminal handle and still reports
/// the partial output collected so far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A handle whose wait-for-exit only resolves once it is killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Tiny timeout so the kill path triggers quickly.
    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The tool should surface the terminal in its tool-call update.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked + 1ms timer per pass)
    // until it completes, or fail after 500ms of wall-clock time.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
242
/// Counterpart to the timeout test: with `timeout_ms: None` the tool must NOT
/// kill the handle while the command is still running.
/// NOTE(review): currently `#[ignore]`d — only runs when requested explicitly.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // No timeout: the task is intentionally left running (and dropped).
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool a moment to (incorrectly) trigger a kill, then verify it
    // did not.
    smol::Timer::after(Duration::from_millis(25)).await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
290
/// Verifies that a streamed `Thinking` event is rendered in the thread's
/// markdown as a `<think>` block preceding the final text.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step followed by the answer text.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
334
/// Verifies the system prompt is the first request message and reflects the
/// project context (shell name) and, because a tool is registered, includes
/// the "Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so we can spot it in the prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    // The system prompt must lead the message list.
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
379
/// Verifies that when no tools are registered, the system prompt omits the
/// tool-related sections ("Tool Use" and "Fixing Diagnostics").
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
415
/// Verifies the cache-point policy: only the most recent user message (or
/// tool result) in a request carries `cache: true`; all earlier messages are
/// sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so compare from index 1 onward.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // The trailing tool-result message is now the cache point.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
561
/// End-to-end test (requires the "e2e" feature and a real Sonnet 4 model):
/// exercises a tool that finishes before streaming ends and one that finishes
/// after, asserting both turns end normally.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's "Ding" output.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
621
/// End-to-end test: asserts that tool-use input streams incrementally — we
/// must observe at least one tool call whose input is still incomplete
/// (missing later keys) before the final, complete input arrives.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming: the
                        // late key "g" being absent marks a partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
673
674#[gpui::test]
675async fn test_tool_authorization(cx: &mut TestAppContext) {
676 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
677 let fake_model = model.as_fake();
678
679 let mut events = thread
680 .update(cx, |thread, cx| {
681 thread.add_tool(ToolRequiringPermission);
682 thread.send(UserMessageId::new(), ["abc"], cx)
683 })
684 .unwrap();
685 cx.run_until_parked();
686 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
687 LanguageModelToolUse {
688 id: "tool_id_1".into(),
689 name: ToolRequiringPermission::name().into(),
690 raw_input: "{}".into(),
691 input: json!({}),
692 is_input_complete: true,
693 thought_signature: None,
694 },
695 ));
696 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
697 LanguageModelToolUse {
698 id: "tool_id_2".into(),
699 name: ToolRequiringPermission::name().into(),
700 raw_input: "{}".into(),
701 input: json!({}),
702 is_input_complete: true,
703 thought_signature: None,
704 },
705 ));
706 fake_model.end_last_completion_stream();
707 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
708 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
709
710 // Approve the first
711 tool_call_auth_1
712 .response
713 .send(tool_call_auth_1.options[1].option_id.clone())
714 .unwrap();
715 cx.run_until_parked();
716
717 // Reject the second
718 tool_call_auth_2
719 .response
720 .send(tool_call_auth_1.options[2].option_id.clone())
721 .unwrap();
722 cx.run_until_parked();
723
724 let completion = fake_model.pending_completions().pop().unwrap();
725 let message = completion.messages.last().unwrap();
726 assert_eq!(
727 message.content,
728 vec![
729 language_model::MessageContent::ToolResult(LanguageModelToolResult {
730 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
731 tool_name: ToolRequiringPermission::name().into(),
732 is_error: false,
733 content: "Allowed".into(),
734 output: Some("Allowed".into())
735 }),
736 language_model::MessageContent::ToolResult(LanguageModelToolResult {
737 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
738 tool_name: ToolRequiringPermission::name().into(),
739 is_error: true,
740 content: "Permission to run tool denied by user".into(),
741 output: Some("Permission to run tool denied by user".into())
742 })
743 ]
744 );
745
746 // Simulate yet another tool call.
747 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
748 LanguageModelToolUse {
749 id: "tool_id_3".into(),
750 name: ToolRequiringPermission::name().into(),
751 raw_input: "{}".into(),
752 input: json!({}),
753 is_input_complete: true,
754 thought_signature: None,
755 },
756 ));
757 fake_model.end_last_completion_stream();
758
759 // Respond by always allowing tools.
760 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
761 tool_call_auth_3
762 .response
763 .send(tool_call_auth_3.options[0].option_id.clone())
764 .unwrap();
765 cx.run_until_parked();
766 let completion = fake_model.pending_completions().pop().unwrap();
767 let message = completion.messages.last().unwrap();
768 assert_eq!(
769 message.content,
770 vec![language_model::MessageContent::ToolResult(
771 LanguageModelToolResult {
772 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
773 tool_name: ToolRequiringPermission::name().into(),
774 is_error: false,
775 content: "Allowed".into(),
776 output: Some("Allowed".into())
777 }
778 )]
779 );
780
781 // Simulate a final tool call, ensuring we don't trigger authorization.
782 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
783 LanguageModelToolUse {
784 id: "tool_id_4".into(),
785 name: ToolRequiringPermission::name().into(),
786 raw_input: "{}".into(),
787 input: json!({}),
788 is_input_complete: true,
789 thought_signature: None,
790 },
791 ));
792 fake_model.end_last_completion_stream();
793 cx.run_until_parked();
794 let completion = fake_model.pending_completions().pop().unwrap();
795 let message = completion.messages.last().unwrap();
796 assert_eq!(
797 message.content,
798 vec![language_model::MessageContent::ToolResult(
799 LanguageModelToolResult {
800 tool_use_id: "tool_id_4".into(),
801 tool_name: ToolRequiringPermission::name().into(),
802 is_error: false,
803 content: "Allowed".into(),
804 output: Some("Allowed".into())
805 }
806 )]
807 );
808}
809
/// Verifies that a tool use naming a tool that was never registered still
/// surfaces as a tool call, then immediately transitions to `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The model "hallucinates" a tool that doesn't exist on the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
839
/// Verifies that after hitting the tool-use limit, `resume` replays the whole
/// conversation plus a "Continue where you left off" user message (which
/// becomes the new cache point), and the thread can then finish normally.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // After the tool runs, the follow-up request carries the tool result as
    // the cache point.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    // The stream's final event must be the ToolUseLimitReached error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Resume appends a synthetic user message, which takes over as the
    // cache point.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
954
/// Verifies that after hitting the tool-use limit, sending a brand-new user
/// message (instead of resuming) keeps the prior tool use/result in history
/// and makes the new message the cache point.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use immediately followed by the limit being reached.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1031
1032async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1033 let event = events
1034 .next()
1035 .await
1036 .expect("no tool call authorization event received")
1037 .unwrap();
1038 match event {
1039 ThreadEvent::ToolCall(tool_call) => tool_call,
1040 event => {
1041 panic!("Unexpected event {event:?}");
1042 }
1043 }
1044}
1045
1046async fn expect_tool_call_update_fields(
1047 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1048) -> acp::ToolCallUpdate {
1049 let event = events
1050 .next()
1051 .await
1052 .expect("no tool call authorization event received")
1053 .unwrap();
1054 match event {
1055 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1056 event => {
1057 panic!("Unexpected event {event:?}");
1058 }
1059 }
1060}
1061
1062async fn next_tool_call_authorization(
1063 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1064) -> ToolCallAuthorization {
1065 loop {
1066 let event = events
1067 .next()
1068 .await
1069 .expect("no tool call authorization event received")
1070 .unwrap();
1071 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1072 let permission_kinds = tool_call_authorization
1073 .options
1074 .iter()
1075 .map(|o| o.kind)
1076 .collect::<Vec<_>>();
1077 assert_eq!(
1078 permission_kinds,
1079 vec![
1080 acp::PermissionOptionKind::AllowAlways,
1081 acp::PermissionOptionKind::AllowOnce,
1082 acp::PermissionOptionKind::RejectOnce,
1083 ]
1084 );
1085 return tool_call_authorization;
1086 }
1087 }
1088}
1089
/// End-to-end test: two delay-tool calls issued in the same message must both
/// complete, the turn must end normally, and the model must describe the
/// tools' "Ding" outputs.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate all text content of the final assistant message.
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
1134
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which of the thread's
    // registered tools are offered to the model, and that switching profiles
    // takes effect on the next completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register three tools; the profiles below each enable a subset.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the request the fake model received to see which tools were
    // actually offered.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1214
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Exercises the full round trip for MCP (context-server) tools: the model
    // calls an MCP tool, the tool responds, and the result is fed back into
    // the next request. Also verifies that a name collision between a native
    // tool and an MCP tool is resolved by prefixing the MCP tool with its
    // server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a context server exposing a single "echo" tool; the returned
    // stream yields (params, response-sender) pairs for each tool invocation.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No native "echo" is registered yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The context server should have received the call; answer it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the first request popped
    // above, so this re-asserts the same snapshot; a fresh
    // `fake_model.pending_completions().pop()` was likely intended — confirm.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is now disambiguated as "test_server_echo".
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Model calls both the MCP echo and the native echo in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1380
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how tool names are disambiguated and truncated when multiple
    // context servers expose conflicting and/or overly long tool names:
    // conflicts get a server-name prefix, and names at or over
    // MAX_TOOL_NAME_LENGTH are truncated (prefixes shortened as needed).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names that conflict across servers "yyy" and "zzz":
            // prefixing would overflow MAX_TOOL_NAME_LENGTH.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // A name already over the limit, with no conflict.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request shows the final, de-duplicated, truncated names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1546
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): cancelling a turn
    // while a tool is still running must close the event stream with a
    // Cancelled stop reason, and the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Track when the echo tool (the one that can finish) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1632
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    // Sending a new message while a previous send is still streaming should
    // cancel the first send (its stream ends with Cancelled) while the second
    // proceeds to a normal EndTurn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Start the first send and let it stream a partial response.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Start a second send before the first completes; then finish it normally.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1663
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    // Counterpart to the cancellation test above: when the first send is
    // allowed to finish before the second starts, both turns should end with
    // EndTurn and neither should be cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send, run to completion before starting the next.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    // Second send, also run to completion.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1696
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with a Refusal, the thread should surface the
    // Refusal stop reason and roll back the transcript to before the last
    // user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any response streams, the transcript holds only the user turn.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // A partial assistant response appears once a chunk is streamed.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The entire exchange (including the user message) is gone after refusal.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1745
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message should empty the thread
    // and reset token usage, while leaving the thread usable for new sends.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update for this turn.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1861
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to its
    // state after the first exchange, including the first turn's token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange, with a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Shared assertion: the thread looks exactly as it did after the first
    // exchange (used both before the second send and after truncation).
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; remember its id so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and verify the first exchange survives.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1970
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model should generate the thread title from
    // the first exchange (only the first line of its output is used), and
    // should not be invoked again for subsequent messages.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model so title-generation requests can be observed
    // independently of the main completion stream.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the placeholder until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line ("Hello world") becomes the title; the text after
    // the newline is discarded.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2016
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When building a completion request while a tool call is still awaiting
    // user permission, that pending tool use (and its missing result) must be
    // omitted; completed tool uses are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // One tool that blocks on permission, one that completes immediately.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped here — presumably the system message; confirm.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // Only the echo tool use appears; the permission-gated tool use
            // ("tool_id_1") is excluded because it has no result yet.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2096
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test for the ACP agent connection: thread creation, model
    // listing/selection, prompting, cancellation, and session teardown when
    // the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the connection and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2227
/// Verifies the sequence of ACP tool-call updates emitted while a tool's input
/// streams in and the tool subsequently runs to completion: an initial
/// `ToolCall` for the partial input, a fields update once input is complete,
/// an `InProgress` status, streamed content, and a final `Completed` status
/// with the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The partial tool use surfaces as an initial ToolCall carrying the
    // (still-empty) raw input and the tool name in metadata.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Thinking")
            .kind(acp::ToolKind::Think)
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([(
                "tool_name".into(),
                "thinking".into()
            )]))
    );
    // Completed input arrives as a fields update with the full raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Thinking")
                .kind(acp::ToolKind::Think)
                .raw_input(json!({ "content": "Thinking hard!"}))
        )
    );
    // The tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    // The tool streams its content while running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
        )
    );
    // Finally the call completes with the tool's raw output attached.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Finished thinking.")
        )
    );
}
2318
/// Verifies that a completion which streams successfully to the end produces
/// no `Retry` events and leaves the expected transcript in the thread.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries are issued
            // in the sibling retry tests; here the stream succeeds, so none
            // should be observed.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a successful response and close the stream cleanly.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, collecting any retry notifications.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2362
/// Verifies that a mid-stream `ServerOverloaded` error triggers exactly one
/// retry after the provider-supplied `retry_after` delay, and that the resumed
/// completion is recorded as a separate assistant message in the transcript.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with an overloaded error that asks
    // the client to retry after 3 seconds.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt #1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2427
/// Verifies that when a completion fails after the model has emitted a tool
/// use, the pending tool call is still executed and its result is included in
/// the retried request's message history.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model requests an echo-tool invocation, then the stream errors out
    // with a retryable overloaded error before producing any text.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the retried request must contain the user
    // message, the assistant's tool use, and the echo tool's result —
    // proving the tool call was finished despite the stream error.
    // (messages[0] is the system prompt, hence the [1..] slice.)
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2508
/// Verifies that after `MAX_RETRY_ATTEMPTS` consecutive retryable failures the
/// thread stops retrying: it emits one `Retry` event per attempt (numbered
/// from 1) and then surfaces the final `ServerOverloaded` error to the caller.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX_RETRY_ATTEMPTS + 1
    // failures), advancing the clock past each retry_after delay.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, separating retry notifications from errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    // Attempts are numbered 1..=MAX_RETRY_ATTEMPTS.
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // Only the final failure (after retries are exhausted) becomes an error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2563
2564/// Filters out the stop events for asserting against in tests
2565fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2566 result_events
2567 .into_iter()
2568 .filter_map(|event| match event.unwrap() {
2569 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2570 _ => None,
2571 })
2572 .collect()
2573}
2574
/// Bundle of entities produced by `setup` that tests use to drive a thread.
struct ThreadTest {
    // The language model completions are sent to (fake or real, per TestModel).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used by tests that register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
2582
/// Which language model a test runs against: a real Claude Sonnet 4 (requires
/// authentication; see `setup`) or the in-process fake model.
enum TestModel {
    Sonnet4,
    Fake,
}
2587
2588impl TestModel {
2589 fn id(&self) -> LanguageModelId {
2590 match self {
2591 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2592 TestModel::Fake => unreachable!(),
2593 }
2594 }
2595}
2596
/// Builds the full test fixture: a fake filesystem with a settings file that
/// enables every test tool, an initialized settings store watching that file,
/// a test project, the requested language model, and a `Thread` wired to all
/// of the above.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work (e.g. HTTP); allow it.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file whose default profile enables all of the tools
    // the tests in this module register on the thread.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real model: set up HTTP, a production client, and the
                // language-model providers so the model can be resolved
                // from the registry and authenticated below.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the settings file above.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2698
/// Runs once before the test binary starts (via `ctor`) and enables
/// `env_logger` output, but only when the caller opted in through `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    match std::env::var("RUST_LOG") {
        Ok(_) => env_logger::init(),
        Err(_) => {}
    }
}
2706
2707fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2708 let fs = fs.clone();
2709 cx.spawn({
2710 async move |cx| {
2711 let mut new_settings_content_rx = settings::watch_config_file(
2712 cx.background_executor(),
2713 fs,
2714 paths::settings_file().clone(),
2715 );
2716
2717 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2718 cx.update(|cx| {
2719 SettingsStore::update_global(cx, |settings, cx| {
2720 settings.set_user_settings(&new_settings_content, cx)
2721 })
2722 })
2723 .ok();
2724 }
2725 }
2726 })
2727 .detach();
2728}
2729
2730fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2731 completion
2732 .tools
2733 .iter()
2734 .map(|tool| tool.name.clone())
2735 .collect()
2736}
2737
/// Registers a fake MCP context server named `name` exposing `tools`, starts
/// it in `context_server_store`, and returns a channel of incoming tool calls.
/// Each received item is the call's parameters plus a oneshot sender the test
/// must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is a dummy; the fake transport below intercepts all I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport answering the three requests the store issues:
    // Initialize, ListTools, and CallTool.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block this request until
                // the test sends a response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish initializing before the test proceeds.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2816
/// Verifies `tokens_before_message`: the first user message has no tokens
/// before it, and each subsequent message reports the `input_tokens` from the
/// request that preceded it (the cumulative context size at that point).
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
2923
/// Verifies that truncating the thread at a message removes that message's
/// token bookkeeping: `tokens_before_message` returns `None` for the removed
/// message, while earlier messages keep their (empty) accounting.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}