1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
118#[gpui::test]
119async fn test_system_prompt(cx: &mut TestAppContext) {
120 let ThreadTest {
121 model,
122 thread,
123 project_context,
124 ..
125 } = setup(cx, TestModel::Fake).await;
126 let fake_model = model.as_fake();
127
128 project_context.update(cx, |project_context, _cx| {
129 project_context.shell = "test-shell".into()
130 });
131 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
132 thread
133 .update(cx, |thread, cx| {
134 thread.send(UserMessageId::new(), ["abc"], cx)
135 })
136 .unwrap();
137 cx.run_until_parked();
138 let mut pending_completions = fake_model.pending_completions();
139 assert_eq!(
140 pending_completions.len(),
141 1,
142 "unexpected pending completions: {:?}",
143 pending_completions
144 );
145
146 let pending_completion = pending_completions.pop().unwrap();
147 assert_eq!(pending_completion.messages[0].role, Role::System);
148
149 let system_message = &pending_completion.messages[0];
150 let system_prompt = system_message.content[0].to_str().unwrap();
151 assert!(
152 system_prompt.contains("test-shell"),
153 "unexpected system message: {:?}",
154 system_message
155 );
156 assert!(
157 system_prompt.contains("## Fixing Diagnostics"),
158 "unexpected system message: {:?}",
159 system_message
160 );
161}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy visible in the request messages:
    // exactly the newest message in each request carries `cache: true`, and
    // every earlier message is re-sent with `cache: false`. This is checked
    // across plain user/assistant turns and a tool-use/tool-result turn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool result is delivered back to the model in a follow-up request;
    // it is now the newest message and therefore the only cached one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
344
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // exercises tool calls that finish both before and after the model's
    // response stream ends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // DelayTool's output contains "Ding"; the model was asked to echo it, so
    // the final assistant message should mention it somewhere.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
404
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // verifies that tool-call input streams incrementally, i.e. that at least
    // one ToolCall event is observed while the input is still incomplete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Trigger the word_list tool so we can watch its input stream in.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be arriving; a missing
                        // "g" key (presumably one of the later-streamed fields)
                        // indicates we caught it mid-stream.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
456
457#[gpui::test]
458async fn test_tool_authorization(cx: &mut TestAppContext) {
459 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
460 let fake_model = model.as_fake();
461
462 let mut events = thread
463 .update(cx, |thread, cx| {
464 thread.add_tool(ToolRequiringPermission);
465 thread.send(UserMessageId::new(), ["abc"], cx)
466 })
467 .unwrap();
468 cx.run_until_parked();
469 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
470 LanguageModelToolUse {
471 id: "tool_id_1".into(),
472 name: ToolRequiringPermission::name().into(),
473 raw_input: "{}".into(),
474 input: json!({}),
475 is_input_complete: true,
476 thought_signature: None,
477 },
478 ));
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_2".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 thought_signature: None,
487 },
488 ));
489 fake_model.end_last_completion_stream();
490 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
491 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
492
493 // Approve the first
494 tool_call_auth_1
495 .response
496 .send(tool_call_auth_1.options[1].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499
500 // Reject the second
501 tool_call_auth_2
502 .response
503 .send(tool_call_auth_1.options[2].id.clone())
504 .unwrap();
505 cx.run_until_parked();
506
507 let completion = fake_model.pending_completions().pop().unwrap();
508 let message = completion.messages.last().unwrap();
509 assert_eq!(
510 message.content,
511 vec![
512 language_model::MessageContent::ToolResult(LanguageModelToolResult {
513 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
514 tool_name: ToolRequiringPermission::name().into(),
515 is_error: false,
516 content: "Allowed".into(),
517 output: Some("Allowed".into())
518 }),
519 language_model::MessageContent::ToolResult(LanguageModelToolResult {
520 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
521 tool_name: ToolRequiringPermission::name().into(),
522 is_error: true,
523 content: "Permission to run tool denied by user".into(),
524 output: Some("Permission to run tool denied by user".into())
525 })
526 ]
527 );
528
529 // Simulate yet another tool call.
530 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
531 LanguageModelToolUse {
532 id: "tool_id_3".into(),
533 name: ToolRequiringPermission::name().into(),
534 raw_input: "{}".into(),
535 input: json!({}),
536 is_input_complete: true,
537 thought_signature: None,
538 },
539 ));
540 fake_model.end_last_completion_stream();
541
542 // Respond by always allowing tools.
543 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
544 tool_call_auth_3
545 .response
546 .send(tool_call_auth_3.options[0].id.clone())
547 .unwrap();
548 cx.run_until_parked();
549 let completion = fake_model.pending_completions().pop().unwrap();
550 let message = completion.messages.last().unwrap();
551 assert_eq!(
552 message.content,
553 vec![language_model::MessageContent::ToolResult(
554 LanguageModelToolResult {
555 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
556 tool_name: ToolRequiringPermission::name().into(),
557 is_error: false,
558 content: "Allowed".into(),
559 output: Some("Allowed".into())
560 }
561 )]
562 );
563
564 // Simulate a final tool call, ensuring we don't trigger authorization.
565 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
566 LanguageModelToolUse {
567 id: "tool_id_4".into(),
568 name: ToolRequiringPermission::name().into(),
569 raw_input: "{}".into(),
570 input: json!({}),
571 is_input_complete: true,
572 thought_signature: None,
573 },
574 ));
575 fake_model.end_last_completion_stream();
576 cx.run_until_parked();
577 let completion = fake_model.pending_completions().pop().unwrap();
578 let message = completion.messages.last().unwrap();
579 assert_eq!(
580 message.content,
581 vec![language_model::MessageContent::ToolResult(
582 LanguageModelToolResult {
583 tool_use_id: "tool_id_4".into(),
584 tool_name: ToolRequiringPermission::name().into(),
585 is_error: false,
586 content: "Allowed".into(),
587 output: Some("Allowed".into())
588 }
589 )]
590 );
591}
592
593#[gpui::test]
594async fn test_tool_hallucination(cx: &mut TestAppContext) {
595 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
596 let fake_model = model.as_fake();
597
598 let mut events = thread
599 .update(cx, |thread, cx| {
600 thread.send(UserMessageId::new(), ["abc"], cx)
601 })
602 .unwrap();
603 cx.run_until_parked();
604 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
605 LanguageModelToolUse {
606 id: "tool_id_1".into(),
607 name: "nonexistent_tool".into(),
608 raw_input: "{}".into(),
609 input: json!({}),
610 is_input_complete: true,
611 thought_signature: None,
612 },
613 ));
614 fake_model.end_last_completion_stream();
615
616 let tool_call = expect_tool_call(&mut events).await;
617 assert_eq!(tool_call.title, "nonexistent_tool");
618 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
619 let update = expect_tool_call_update_fields(&mut events).await;
620 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
621}
622
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, `resume` should replay
    // the conversation with an extra "Continue where you left off" user
    // message appended (which becomes the newly cached message).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran; its result is sent back in a follow-up request and is the
    // newest (cached) message. messages[0] is the system prompt.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming replays history plus the synthetic continuation message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
739
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After ToolUseLimitReached, sending a brand-new user message (rather
    // than resuming) should still replay the prior tool use/result history,
    // with the new message appended as the cached one.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately hit the limit in the same stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new "ghi" message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
818
819async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
820 let event = events
821 .next()
822 .await
823 .expect("no tool call authorization event received")
824 .unwrap();
825 match event {
826 ThreadEvent::ToolCall(tool_call) => tool_call,
827 event => {
828 panic!("Unexpected event {event:?}");
829 }
830 }
831}
832
833async fn expect_tool_call_update_fields(
834 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
835) -> acp::ToolCallUpdate {
836 let event = events
837 .next()
838 .await
839 .expect("no tool call authorization event received")
840 .unwrap();
841 match event {
842 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
843 event => {
844 panic!("Unexpected event {event:?}");
845 }
846 }
847}
848
849async fn next_tool_call_authorization(
850 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
851) -> ToolCallAuthorization {
852 loop {
853 let event = events
854 .next()
855 .await
856 .expect("no tool call authorization event received")
857 .unwrap();
858 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
859 let permission_kinds = tool_call_authorization
860 .options
861 .iter()
862 .map(|o| o.kind)
863 .collect::<Vec<_>>();
864 assert_eq!(
865 permission_kinds,
866 vec![
867 acp::PermissionOptionKind::AllowAlways,
868 acp::PermissionOptionKind::AllowOnce,
869 acp::PermissionOptionKind::RejectOnce,
870 ]
871 );
872 return tool_call_authorization;
873 }
874 }
875}
876
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // two delay-tool calls issued in the same turn should both complete, and
    // the model should then describe their outputs.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // DelayTool's output contains "Ding"; check the concatenated text of the
    // final assistant message mentions it.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
921
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are offered to the model in the completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1001
1002#[gpui::test]
1003async fn test_mcp_tools(cx: &mut TestAppContext) {
1004 let ThreadTest {
1005 model,
1006 thread,
1007 context_server_store,
1008 fs,
1009 ..
1010 } = setup(cx, TestModel::Fake).await;
1011 let fake_model = model.as_fake();
1012
1013 // Override profiles and wait for settings to be loaded.
1014 fs.insert_file(
1015 paths::settings_file(),
1016 json!({
1017 "agent": {
1018 "always_allow_tool_actions": true,
1019 "profiles": {
1020 "test": {
1021 "name": "Test Profile",
1022 "enable_all_context_servers": true,
1023 "tools": {
1024 EchoTool::name(): true,
1025 }
1026 },
1027 }
1028 }
1029 })
1030 .to_string()
1031 .into_bytes(),
1032 )
1033 .await;
1034 cx.run_until_parked();
1035 thread.update(cx, |thread, cx| {
1036 thread.set_profile(AgentProfileId("test".into()), cx)
1037 });
1038
1039 let mut mcp_tool_calls = setup_context_server(
1040 "test_server",
1041 vec![context_server::types::Tool {
1042 name: "echo".into(),
1043 description: None,
1044 input_schema: serde_json::to_value(EchoTool::input_schema(
1045 LanguageModelToolSchemaFormat::JsonSchema,
1046 ))
1047 .unwrap(),
1048 output_schema: None,
1049 annotations: None,
1050 }],
1051 &context_server_store,
1052 cx,
1053 );
1054
1055 let events = thread.update(cx, |thread, cx| {
1056 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1057 });
1058 cx.run_until_parked();
1059
1060 // Simulate the model calling the MCP tool.
1061 let completion = fake_model.pending_completions().pop().unwrap();
1062 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1063 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1064 LanguageModelToolUse {
1065 id: "tool_1".into(),
1066 name: "echo".into(),
1067 raw_input: json!({"text": "test"}).to_string(),
1068 input: json!({"text": "test"}),
1069 is_input_complete: true,
1070 thought_signature: None,
1071 },
1072 ));
1073 fake_model.end_last_completion_stream();
1074 cx.run_until_parked();
1075
1076 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1077 assert_eq!(tool_call_params.name, "echo");
1078 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1079 tool_call_response
1080 .send(context_server::types::CallToolResponse {
1081 content: vec![context_server::types::ToolResponseContent::Text {
1082 text: "test".into(),
1083 }],
1084 is_error: None,
1085 meta: None,
1086 structured_content: None,
1087 })
1088 .unwrap();
1089 cx.run_until_parked();
1090
1091 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1092 fake_model.send_last_completion_stream_text_chunk("Done!");
1093 fake_model.end_last_completion_stream();
1094 events.collect::<Vec<_>>().await;
1095
1096 // Send again after adding the echo tool, ensuring the name collision is resolved.
1097 let events = thread.update(cx, |thread, cx| {
1098 thread.add_tool(EchoTool);
1099 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1100 });
1101 cx.run_until_parked();
1102 let completion = fake_model.pending_completions().pop().unwrap();
1103 assert_eq!(
1104 tool_names_for_completion(&completion),
1105 vec!["echo", "test_server_echo"]
1106 );
1107 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1108 LanguageModelToolUse {
1109 id: "tool_2".into(),
1110 name: "test_server_echo".into(),
1111 raw_input: json!({"text": "mcp"}).to_string(),
1112 input: json!({"text": "mcp"}),
1113 is_input_complete: true,
1114 thought_signature: None,
1115 },
1116 ));
1117 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1118 LanguageModelToolUse {
1119 id: "tool_3".into(),
1120 name: "echo".into(),
1121 raw_input: json!({"text": "native"}).to_string(),
1122 input: json!({"text": "native"}),
1123 is_input_complete: true,
1124 thought_signature: None,
1125 },
1126 ));
1127 fake_model.end_last_completion_stream();
1128 cx.run_until_parked();
1129
1130 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1131 assert_eq!(tool_call_params.name, "echo");
1132 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1133 tool_call_response
1134 .send(context_server::types::CallToolResponse {
1135 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1136 is_error: None,
1137 meta: None,
1138 structured_content: None,
1139 })
1140 .unwrap();
1141 cx.run_until_parked();
1142
1143 // Ensure the tool results were inserted with the correct names.
1144 let completion = fake_model.pending_completions().pop().unwrap();
1145 assert_eq!(
1146 completion.messages.last().unwrap().content,
1147 vec![
1148 MessageContent::ToolResult(LanguageModelToolResult {
1149 tool_use_id: "tool_3".into(),
1150 tool_name: "echo".into(),
1151 is_error: false,
1152 content: "native".into(),
1153 output: Some("native".into()),
1154 },),
1155 MessageContent::ToolResult(LanguageModelToolResult {
1156 tool_use_id: "tool_2".into(),
1157 tool_name: "test_server_echo".into(),
1158 is_error: false,
1159 content: "mcp".into(),
1160 output: Some("mcp".into()),
1161 },),
1162 ]
1163 );
1164 fake_model.end_last_completion_stream();
1165 events.collect::<Vec<_>>().await;
1166}
1167
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Exercises tool-name disambiguation when several MCP context servers
    // expose tools whose names collide with native tools, collide with each
    // other, or exceed MAX_TOOL_NAME_LENGTH.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register all native tools on the thread.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names: a full "yyy_" prefix would overflow the limit, so
            // the expected list below shows a shortened "y_" prefix instead.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            // Same long names as server "yyy" to force cross-server collisions.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character over the limit: expected to be truncated to fit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Colliding "echo" names get a server-id prefix; over-long names are
    // truncated or get a single-letter server prefix.
    // NOTE(review): only one "bbb…" entry appears even though both "yyy" and
    // "zzz" expose it — presumably names that cannot be disambiguated within
    // the length limit collapse to one; confirm against the truncation logic.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (runs against the real Sonnet 4 model, hence ignored
    // unless the "e2e" feature is enabled): cancelling a turn must close the
    // event stream with StopReason::Cancelled even while a tool is still
    // running, and the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the order the prompt requested them.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // NOTE(review): this arm only matches updates whose `meta` is
            // `None` — presumably completion updates carry no meta here;
            // confirm if `meta` is ever populated for status updates.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have been called and the echo tool finished;
        // the infinite tool is, by design, still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1419
1420#[gpui::test]
1421async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1422 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1423 let fake_model = model.as_fake();
1424
1425 let events_1 = thread
1426 .update(cx, |thread, cx| {
1427 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1428 })
1429 .unwrap();
1430 cx.run_until_parked();
1431 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1432 cx.run_until_parked();
1433
1434 let events_2 = thread
1435 .update(cx, |thread, cx| {
1436 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1437 })
1438 .unwrap();
1439 cx.run_until_parked();
1440 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1441 fake_model
1442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1443 fake_model.end_last_completion_stream();
1444
1445 let events_1 = events_1.collect::<Vec<_>>().await;
1446 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1447 let events_2 = events_2.collect::<Vec<_>>().await;
1448 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1453 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454 let fake_model = model.as_fake();
1455
1456 let events_1 = thread
1457 .update(cx, |thread, cx| {
1458 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1459 })
1460 .unwrap();
1461 cx.run_until_parked();
1462 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1463 fake_model
1464 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1465 fake_model.end_last_completion_stream();
1466 let events_1 = events_1.collect::<Vec<_>>().await;
1467
1468 let events_2 = thread
1469 .update(cx, |thread, cx| {
1470 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1471 })
1472 .unwrap();
1473 cx.run_until_parked();
1474 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1475 fake_model
1476 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1477 fake_model.end_last_completion_stream();
1478 let events_2 = events_2.collect::<Vec<_>>().await;
1479
1480 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1481 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1482}
1483
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with StopReason::Refusal, the turn must surface
    // acp::StopReason::Refusal and the thread must discard the refused
    // exchange (the transcript ends up completely empty below).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Partial assistant output is visible while the stream is still open.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // NOTE(review): the transcript is fully empty here, so the triggering user
    // message is dropped as well, not just the assistant's partial reply.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1532
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message should clear the whole transcript
    // and reset token accounting, while leaving the thread usable for new
    // messages afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported by the model yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the UsageUpdate event (input + output tokens) against
        // the fake model's 1M-token context window.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message removes it and everything after it,
    // and resets the reported token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage after truncation reflects only the new exchange.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1648
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message must restore the thread to exactly the
    // state it had after the first exchange, including the earlier token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Captures the expected post-first-exchange state; reused after the
    // truncation below to verify the thread rolled back to it exactly.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Run a complete second exchange with different (larger) token usage.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: both it and its response are removed,
    // and token usage reverts to the first exchange's values.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1757
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A dedicated summarization model generates the thread title from the
    // first exchange; subsequent turns must not regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for title summarization.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line of the streamed summary becomes the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new summarization request should have been issued.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1803
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When a completion request is built while a tool call is still awaiting
    // user permission, that pending tool use (and its missing result) must be
    // omitted from the request; completed tool uses are included.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool will block awaiting permission, so it stays pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // The echo tool runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // NOTE(review): messages[0] is skipped by this assertion — presumably the
    // system/context message; confirm against build_completion_request.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // The assistant message contains only the completed echo tool use;
            // the permission-gated tool use is absent.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
1883
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test of the ACP connection surface: thread creation, model
    // listing/selection, prompting, cancellation, and session teardown.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and verify the streamed reply lands in
    // the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2018
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the full sequence of ACP tool-call events emitted for a single
    // tool invocation: initial ToolCall, then updates for streamed input,
    // in-progress status, content, and final completion with raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) The initial tool call, still pending, with the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2) Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3) Status flips to InProgress as the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5) Completion, including the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2127
2128#[gpui::test]
2129async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2130 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2131 let fake_model = model.as_fake();
2132
2133 let mut events = thread
2134 .update(cx, |thread, cx| {
2135 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2136 thread.send(UserMessageId::new(), ["Hello!"], cx)
2137 })
2138 .unwrap();
2139 cx.run_until_parked();
2140
2141 fake_model.send_last_completion_stream_text_chunk("Hey!");
2142 fake_model.end_last_completion_stream();
2143
2144 let mut retry_events = Vec::new();
2145 while let Some(Ok(event)) = events.next().await {
2146 match event {
2147 ThreadEvent::Retry(retry_status) => {
2148 retry_events.push(retry_status);
2149 }
2150 ThreadEvent::Stop(..) => break,
2151 _ => {}
2152 }
2153 }
2154
2155 assert_eq!(retry_events.len(), 0);
2156 thread.read_with(cx, |thread, _cx| {
2157 assert_eq!(
2158 thread.to_markdown(),
2159 indoc! {"
2160 ## User
2161
2162 Hello!
2163
2164 ## Assistant
2165
2166 Hey!
2167 "}
2168 )
2169 });
2170}
2171
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a ServerOverloaded error mid-stream should trigger an
    // automatic retry after the provider's retry_after delay, resuming the
    // same turn and surfacing a single Retry event.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with an overloaded error that asks
    // for a 3-second backoff.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript records the interrupted reply, a [resume] marker, and
    // the continuation from the retried completion.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2236
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a retryable error arrives after a complete tool use was streamed,
    // the tool call should still run, and the retried request must include
    // both the tool use and its result so the model can finish the turn.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a fully-formed tool use, then fail the stream with a retryable
    // error before the turn can complete.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the backoff, the retried request (skipping the system message at
    // index 0) must carry the user message, the assistant's tool use, and a
    // successful tool result.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Let the retried completion finish the turn.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2317
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, the thread
    // must stop retrying and surface the final error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail one more time than the retry budget allows, advancing the clock
    // past each 3-second backoff so every retry attempt is actually made.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per attempt, numbered starting at 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The stream ends with the underlying provider error, not a Stop event.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2372
2373/// Filters out the stop events for asserting against in tests
2374fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2375 result_events
2376 .into_iter()
2377 .filter_map(|event| match event.unwrap() {
2378 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2379 _ => None,
2380 })
2381 .collect()
2382}
2383
/// Fixture bundle returned by [`setup`] that tests use to drive a `Thread`.
struct ThreadTest {
    // The active language model; a `FakeLanguageModel` for `TestModel::Fake`.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used to register (fake) MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2391
/// Which language model a test runs against.
enum TestModel {
    // A real Anthropic model looked up in the registry; requires network and
    // provider authentication (see `setup`).
    Sonnet4,
    // The in-process `FakeLanguageModel`, fully scripted by the test.
    Fake,
}
2396
2397impl TestModel {
2398 fn id(&self) -> LanguageModelId {
2399 match self {
2400 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2401 TestModel::Fake => unreachable!(),
2402 }
2403 }
2404}
2405
/// Builds the shared test fixture: a fake filesystem seeded with a settings
/// file that enables all test tools, a test project rooted at `/test`, and a
/// `Thread` backed by either the fake model or a real, authenticated one.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Allow this thread to block (park); real-model runs do network work.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // tools registered on the thread are permitted by the active profile.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real models need an HTTP client, a production client, and
                // the language-model providers registered globally.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Mirror changes to the settings file above into the global store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // its provider before handing the model to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2507
/// Test-binary constructor: enables `env_logger` output when `RUST_LOG` is
/// set, so tests can be rerun with logging without code changes.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2515
/// Spawns a detached task that mirrors changes to the settings file on `fs`
/// into the global `SettingsStore`, so tests can change settings by writing
/// to the fake filesystem.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            // Re-parse and apply the user settings on every file change.
            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                // The app context may already be gone during teardown; ignore.
                .ok();
            }
        }
    })
    .detach();
}
2538
2539fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2540 completion
2541 .tools
2542 .iter()
2543 .map(|tool| tool.name.clone())
2544 .collect()
2545}
2546
/// Registers a fake MCP context server named `name` that advertises `tools`,
/// and returns a receiver of incoming tool calls.
///
/// Each received item pairs the `CallToolParams` with a oneshot sender the
/// test must fire to deliver that tool call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will run it; the
    // command path is a dummy since the transport is faked below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support with list_changed notifications.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Report the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and await its scripted response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before tests interact with it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}