1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
118#[gpui::test]
119async fn test_system_prompt(cx: &mut TestAppContext) {
120 let ThreadTest {
121 model,
122 thread,
123 project_context,
124 ..
125 } = setup(cx, TestModel::Fake).await;
126 let fake_model = model.as_fake();
127
128 project_context.update(cx, |project_context, _cx| {
129 project_context.shell = "test-shell".into()
130 });
131 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
132 thread
133 .update(cx, |thread, cx| {
134 thread.send(UserMessageId::new(), ["abc"], cx)
135 })
136 .unwrap();
137 cx.run_until_parked();
138 let mut pending_completions = fake_model.pending_completions();
139 assert_eq!(
140 pending_completions.len(),
141 1,
142 "unexpected pending completions: {:?}",
143 pending_completions
144 );
145
146 let pending_completion = pending_completions.pop().unwrap();
147 assert_eq!(pending_completion.messages[0].role, Role::System);
148
149 let system_message = &pending_completion.messages[0];
150 let system_prompt = system_message.content[0].to_str().unwrap();
151 assert!(
152 system_prompt.contains("test-shell"),
153 "unexpected system message: {:?}",
154 system_message
155 );
156 assert!(
157 system_prompt.contains("## Fixing Diagnostics"),
158 "unexpected system message: {:?}",
159 system_message
160 );
161}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: in every request sent to the model,
    // only the *latest* message carries `cache: true`, marking the end of the
    // cacheable prefix. This is checked across three turns: a plain message,
    // a follow-up message, and a tool-use turn whose tool result becomes the
    // final (cached) message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The echo tool just mirrors its input, so the expected result is "test".
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
344
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the `e2e` feature):
    // exercises tool calls that finish both before and after the model stops
    // streaming, and checks the turn completes normally in each case.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops
    // (DelayTool sleeps before returning its "Ding" output).
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's output in its reply.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
404
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4; only runs with the `e2e` feature):
    // verifies that tool-call input is streamed incrementally, i.e. the thread
    // history exposes a tool use whose input is still incomplete while the
    // call is pending, and complete input once the call leaves Pending.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a turn that should trigger the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // the last field ("g") hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
456
457#[gpui::test]
458async fn test_tool_authorization(cx: &mut TestAppContext) {
459 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
460 let fake_model = model.as_fake();
461
462 let mut events = thread
463 .update(cx, |thread, cx| {
464 thread.add_tool(ToolRequiringPermission);
465 thread.send(UserMessageId::new(), ["abc"], cx)
466 })
467 .unwrap();
468 cx.run_until_parked();
469 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
470 LanguageModelToolUse {
471 id: "tool_id_1".into(),
472 name: ToolRequiringPermission::name().into(),
473 raw_input: "{}".into(),
474 input: json!({}),
475 is_input_complete: true,
476 thought_signature: None,
477 },
478 ));
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_2".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 thought_signature: None,
487 },
488 ));
489 fake_model.end_last_completion_stream();
490 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
491 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
492
493 // Approve the first
494 tool_call_auth_1
495 .response
496 .send(tool_call_auth_1.options[1].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499
500 // Reject the second
501 tool_call_auth_2
502 .response
503 .send(tool_call_auth_1.options[2].id.clone())
504 .unwrap();
505 cx.run_until_parked();
506
507 let completion = fake_model.pending_completions().pop().unwrap();
508 let message = completion.messages.last().unwrap();
509 assert_eq!(
510 message.content,
511 vec![
512 language_model::MessageContent::ToolResult(LanguageModelToolResult {
513 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
514 tool_name: ToolRequiringPermission::name().into(),
515 is_error: false,
516 content: "Allowed".into(),
517 output: Some("Allowed".into())
518 }),
519 language_model::MessageContent::ToolResult(LanguageModelToolResult {
520 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
521 tool_name: ToolRequiringPermission::name().into(),
522 is_error: true,
523 content: "Permission to run tool denied by user".into(),
524 output: Some("Permission to run tool denied by user".into())
525 })
526 ]
527 );
528
529 // Simulate yet another tool call.
530 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
531 LanguageModelToolUse {
532 id: "tool_id_3".into(),
533 name: ToolRequiringPermission::name().into(),
534 raw_input: "{}".into(),
535 input: json!({}),
536 is_input_complete: true,
537 thought_signature: None,
538 },
539 ));
540 fake_model.end_last_completion_stream();
541
542 // Respond by always allowing tools.
543 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
544 tool_call_auth_3
545 .response
546 .send(tool_call_auth_3.options[0].id.clone())
547 .unwrap();
548 cx.run_until_parked();
549 let completion = fake_model.pending_completions().pop().unwrap();
550 let message = completion.messages.last().unwrap();
551 assert_eq!(
552 message.content,
553 vec![language_model::MessageContent::ToolResult(
554 LanguageModelToolResult {
555 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
556 tool_name: ToolRequiringPermission::name().into(),
557 is_error: false,
558 content: "Allowed".into(),
559 output: Some("Allowed".into())
560 }
561 )]
562 );
563
564 // Simulate a final tool call, ensuring we don't trigger authorization.
565 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
566 LanguageModelToolUse {
567 id: "tool_id_4".into(),
568 name: ToolRequiringPermission::name().into(),
569 raw_input: "{}".into(),
570 input: json!({}),
571 is_input_complete: true,
572 thought_signature: None,
573 },
574 ));
575 fake_model.end_last_completion_stream();
576 cx.run_until_parked();
577 let completion = fake_model.pending_completions().pop().unwrap();
578 let message = completion.messages.last().unwrap();
579 assert_eq!(
580 message.content,
581 vec![language_model::MessageContent::ToolResult(
582 LanguageModelToolResult {
583 tool_use_id: "tool_id_4".into(),
584 tool_name: ToolRequiringPermission::name().into(),
585 is_error: false,
586 content: "Allowed".into(),
587 output: Some("Allowed".into())
588 }
589 )]
590 );
591}
592
593#[gpui::test]
594async fn test_tool_hallucination(cx: &mut TestAppContext) {
595 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
596 let fake_model = model.as_fake();
597
598 let mut events = thread
599 .update(cx, |thread, cx| {
600 thread.send(UserMessageId::new(), ["abc"], cx)
601 })
602 .unwrap();
603 cx.run_until_parked();
604 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
605 LanguageModelToolUse {
606 id: "tool_id_1".into(),
607 name: "nonexistent_tool".into(),
608 raw_input: "{}".into(),
609 input: json!({}),
610 is_input_complete: true,
611 thought_signature: None,
612 },
613 ));
614 fake_model.end_last_completion_stream();
615
616 let tool_call = expect_tool_call(&mut events).await;
617 assert_eq!(tool_call.title, "nonexistent_tool");
618 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
619 let update = expect_tool_call_update_fields(&mut events).await;
620 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
621}
622
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the model reports a tool-use limit, the turn should end with a
    // `ToolUseLimitReachedError`, and `resume` should start a new request that
    // replays the history plus a "Continue where you left off" user message
    // (which becomes the new cache point).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The echo tool ran; its result should be the final (cached) message in
    // the follow-up request. messages[0] is the system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming should re-send the history with an extra continuation message,
    // which now carries the cache marker.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    // The resumed turn finishes normally and its text lands in the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
737
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After a tool-use limit error, sending a brand-new user message (instead
    // of resuming) should work: the history, including the interrupted tool
    // call and its result, is replayed followed by the new message, which
    // becomes the cache point.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    // Expected result of the echo tool run below.
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit the tool use and immediately hit the limit in the same turn.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new message "ghi" is last and
    // carries the cache marker.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
814
815async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
816 let event = events
817 .next()
818 .await
819 .expect("no tool call authorization event received")
820 .unwrap();
821 match event {
822 ThreadEvent::ToolCall(tool_call) => tool_call,
823 event => {
824 panic!("Unexpected event {event:?}");
825 }
826 }
827}
828
829async fn expect_tool_call_update_fields(
830 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
831) -> acp::ToolCallUpdate {
832 let event = events
833 .next()
834 .await
835 .expect("no tool call authorization event received")
836 .unwrap();
837 match event {
838 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
839 event => {
840 panic!("Unexpected event {event:?}");
841 }
842 }
843}
844
845async fn next_tool_call_authorization(
846 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
847) -> ToolCallAuthorization {
848 loop {
849 let event = events
850 .next()
851 .await
852 .expect("no tool call authorization event received")
853 .unwrap();
854 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
855 let permission_kinds = tool_call_authorization
856 .options
857 .iter()
858 .map(|o| o.kind)
859 .collect::<Vec<_>>();
860 assert_eq!(
861 permission_kinds,
862 vec![
863 acp::PermissionOptionKind::AllowAlways,
864 acp::PermissionOptionKind::AllowOnce,
865 acp::PermissionOptionKind::RejectOnce,
866 ]
867 );
868 return tool_call_authorization;
869 }
870 }
871}
872
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4; only runs with the `e2e` feature):
    // the model is asked to issue two delay tool calls in one message; both
    // should run, the turn should end normally, and the final response should
    // mention the tools' "Ding" output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // that the delay tool's output made it into the model's description.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
917
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles from the settings file control which of
    // the thread's registered tools are offered to the model, and that
    // switching profiles mid-thread takes effect on the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles below will filter this set.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
997
998#[gpui::test]
999async fn test_mcp_tools(cx: &mut TestAppContext) {
1000 let ThreadTest {
1001 model,
1002 thread,
1003 context_server_store,
1004 fs,
1005 ..
1006 } = setup(cx, TestModel::Fake).await;
1007 let fake_model = model.as_fake();
1008
1009 // Override profiles and wait for settings to be loaded.
1010 fs.insert_file(
1011 paths::settings_file(),
1012 json!({
1013 "agent": {
1014 "always_allow_tool_actions": true,
1015 "profiles": {
1016 "test": {
1017 "name": "Test Profile",
1018 "enable_all_context_servers": true,
1019 "tools": {
1020 EchoTool::name(): true,
1021 }
1022 },
1023 }
1024 }
1025 })
1026 .to_string()
1027 .into_bytes(),
1028 )
1029 .await;
1030 cx.run_until_parked();
1031 thread.update(cx, |thread, cx| {
1032 thread.set_profile(AgentProfileId("test".into()), cx)
1033 });
1034
1035 let mut mcp_tool_calls = setup_context_server(
1036 "test_server",
1037 vec![context_server::types::Tool {
1038 name: "echo".into(),
1039 description: None,
1040 input_schema: serde_json::to_value(EchoTool::input_schema(
1041 LanguageModelToolSchemaFormat::JsonSchema,
1042 ))
1043 .unwrap(),
1044 output_schema: None,
1045 annotations: None,
1046 }],
1047 &context_server_store,
1048 cx,
1049 );
1050
1051 let events = thread.update(cx, |thread, cx| {
1052 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1053 });
1054 cx.run_until_parked();
1055
1056 // Simulate the model calling the MCP tool.
1057 let completion = fake_model.pending_completions().pop().unwrap();
1058 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1059 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1060 LanguageModelToolUse {
1061 id: "tool_1".into(),
1062 name: "echo".into(),
1063 raw_input: json!({"text": "test"}).to_string(),
1064 input: json!({"text": "test"}),
1065 is_input_complete: true,
1066 thought_signature: None,
1067 },
1068 ));
1069 fake_model.end_last_completion_stream();
1070 cx.run_until_parked();
1071
1072 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1073 assert_eq!(tool_call_params.name, "echo");
1074 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1075 tool_call_response
1076 .send(context_server::types::CallToolResponse {
1077 content: vec![context_server::types::ToolResponseContent::Text {
1078 text: "test".into(),
1079 }],
1080 is_error: None,
1081 meta: None,
1082 structured_content: None,
1083 })
1084 .unwrap();
1085 cx.run_until_parked();
1086
1087 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1088 fake_model.send_last_completion_stream_text_chunk("Done!");
1089 fake_model.end_last_completion_stream();
1090 events.collect::<Vec<_>>().await;
1091
1092 // Send again after adding the echo tool, ensuring the name collision is resolved.
1093 let events = thread.update(cx, |thread, cx| {
1094 thread.add_tool(EchoTool);
1095 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1096 });
1097 cx.run_until_parked();
1098 let completion = fake_model.pending_completions().pop().unwrap();
1099 assert_eq!(
1100 tool_names_for_completion(&completion),
1101 vec!["echo", "test_server_echo"]
1102 );
1103 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1104 LanguageModelToolUse {
1105 id: "tool_2".into(),
1106 name: "test_server_echo".into(),
1107 raw_input: json!({"text": "mcp"}).to_string(),
1108 input: json!({"text": "mcp"}),
1109 is_input_complete: true,
1110 thought_signature: None,
1111 },
1112 ));
1113 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1114 LanguageModelToolUse {
1115 id: "tool_3".into(),
1116 name: "echo".into(),
1117 raw_input: json!({"text": "native"}).to_string(),
1118 input: json!({"text": "native"}),
1119 is_input_complete: true,
1120 thought_signature: None,
1121 },
1122 ));
1123 fake_model.end_last_completion_stream();
1124 cx.run_until_parked();
1125
1126 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1127 assert_eq!(tool_call_params.name, "echo");
1128 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1129 tool_call_response
1130 .send(context_server::types::CallToolResponse {
1131 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1132 is_error: None,
1133 meta: None,
1134 structured_content: None,
1135 })
1136 .unwrap();
1137 cx.run_until_parked();
1138
1139 // Ensure the tool results were inserted with the correct names.
1140 let completion = fake_model.pending_completions().pop().unwrap();
1141 assert_eq!(
1142 completion.messages.last().unwrap().content,
1143 vec![
1144 MessageContent::ToolResult(LanguageModelToolResult {
1145 tool_use_id: "tool_3".into(),
1146 tool_name: "echo".into(),
1147 is_error: false,
1148 content: "native".into(),
1149 output: Some("native".into()),
1150 },),
1151 MessageContent::ToolResult(LanguageModelToolResult {
1152 tool_use_id: "tool_2".into(),
1153 tool_name: "test_server_echo".into(),
1154 is_error: false,
1155 content: "mcp".into(),
1156 output: Some("mcp".into()),
1157 },),
1158 ]
1159 );
1160 fake_model.end_last_completion_stream();
1161 events.collect::<Vec<_>>().await;
1162}
1163
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP (context server) tool names are disambiguated and
    // truncated when multiple servers expose conflicting or over-long names.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names, also duplicated by server "zzz" below.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Exceeds the limit outright; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected resolution (as asserted below): names conflicting with a native
    // tool get a server-id prefix ("xxx_echo", "yyy_echo"); names that would
    // exceed MAX_TOOL_NAME_LENGTH after prefixing appear to be shortened with a
    // single-letter server prefix ("y_aaa…", "z_aaa…") — exact scheme lives in
    // the production code, not visible here.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1329
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)] // runs against a real model; only in e2e CI
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end: cancelling a turn mid-tool-call must close the event stream
    // and leave the thread usable for a subsequent send.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools must be reported in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only the echo tool can finish; the infinite tool never completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1415
1416#[gpui::test]
1417async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1418 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1419 let fake_model = model.as_fake();
1420
1421 let events_1 = thread
1422 .update(cx, |thread, cx| {
1423 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1424 })
1425 .unwrap();
1426 cx.run_until_parked();
1427 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1428 cx.run_until_parked();
1429
1430 let events_2 = thread
1431 .update(cx, |thread, cx| {
1432 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1433 })
1434 .unwrap();
1435 cx.run_until_parked();
1436 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1437 fake_model
1438 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1439 fake_model.end_last_completion_stream();
1440
1441 let events_1 = events_1.collect::<Vec<_>>().await;
1442 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1443 let events_2 = events_2.collect::<Vec<_>>().await;
1444 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1445}
1446
1447#[gpui::test]
1448async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1449 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1450 let fake_model = model.as_fake();
1451
1452 let events_1 = thread
1453 .update(cx, |thread, cx| {
1454 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1455 })
1456 .unwrap();
1457 cx.run_until_parked();
1458 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1459 fake_model
1460 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1461 fake_model.end_last_completion_stream();
1462 let events_1 = events_1.collect::<Vec<_>>().await;
1463
1464 let events_2 = thread
1465 .update(cx, |thread, cx| {
1466 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1467 })
1468 .unwrap();
1469 cx.run_until_parked();
1470 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1471 fake_model
1472 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1473 fake_model.end_last_completion_stream();
1474 let events_2 = events_2.collect::<Vec<_>>().await;
1475
1476 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1477 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1478}
1479
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // A model refusal must roll the thread back to before the refused turn's
    // user message (here, back to an empty thread).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Streamed assistant text is visible in the transcript immediately.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1528
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) message must empty the thread, reset
    // token usage, and still allow sending afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported by the model yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the UsageUpdate above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: everything is removed, usage resets.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is added synchronously, before the model replies.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage now reflects the post-truncation turn only.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1644
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message must restore the thread (and its token
    // usage) to exactly the state after the first exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before and after truncation: the thread
    // should look identical in both cases.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose id we keep so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the second message drops it and its response, restoring
    // the first-message state (including the earlier token usage).
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1753
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A dedicated summarization model generates the thread title after the
    // first exchange; later exchanges must not regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for title/summary generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks deliberately split words and include a newline; only the
    // first line ("Hello world") becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1799
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When building a completion request, tool calls that are still awaiting
    // resolution (e.g. user permission) must be omitted, while completed tool
    // calls are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion and produces a result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped: it is the system/context message added by setup.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // Only the echo tool use appears; the permission tool is pending.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
1879
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test of the NativeAgent / AgentConnection surface: thread
    // creation, model listing/selection, prompting, cancellation, and session
    // teardown when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a response from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2014
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the exact sequence of ACP events emitted for a tool call:
    // initial ToolCall (Pending), input update, InProgress, content, and
    // final Completed with the tool's raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. Initial tool call: Pending status, partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. Status transition to InProgress once the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. Streaming content produced by the tool.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. Final transition to Completed with the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2123
2124#[gpui::test]
2125async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2126 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2127 let fake_model = model.as_fake();
2128
2129 let mut events = thread
2130 .update(cx, |thread, cx| {
2131 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2132 thread.send(UserMessageId::new(), ["Hello!"], cx)
2133 })
2134 .unwrap();
2135 cx.run_until_parked();
2136
2137 fake_model.send_last_completion_stream_text_chunk("Hey!");
2138 fake_model.end_last_completion_stream();
2139
2140 let mut retry_events = Vec::new();
2141 while let Some(Ok(event)) = events.next().await {
2142 match event {
2143 ThreadEvent::Retry(retry_status) => {
2144 retry_events.push(retry_status);
2145 }
2146 ThreadEvent::Stop(..) => break,
2147 _ => {}
2148 }
2149 }
2150
2151 assert_eq!(retry_events.len(), 0);
2152 thread.read_with(cx, |thread, _cx| {
2153 assert_eq!(
2154 thread.to_markdown(),
2155 indoc! {"
2156 ## User
2157
2158 Hello!
2159
2160 ## Assistant
2161
2162 Hey!
2163 "}
2164 )
2165 });
2166}
2167
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A ServerOverloaded error mid-stream should trigger a single retry after
    // the provider-supplied delay, resuming the turn in a new assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with an overloaded error that asks
    // for a 3-second backoff.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the partial reply and marks the resumption point.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2232
2233#[gpui::test]
2234async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
2235 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2236 let fake_model = model.as_fake();
2237
2238 let events = thread
2239 .update(cx, |thread, cx| {
2240 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2241 thread.add_tool(EchoTool);
2242 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
2243 })
2244 .unwrap();
2245 cx.run_until_parked();
2246
2247 let tool_use_1 = LanguageModelToolUse {
2248 id: "tool_1".into(),
2249 name: EchoTool::name().into(),
2250 raw_input: json!({"text": "test"}).to_string(),
2251 input: json!({"text": "test"}),
2252 is_input_complete: true,
2253 thought_signature: None,
2254 };
2255 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2256 tool_use_1.clone(),
2257 ));
2258 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2259 provider: LanguageModelProviderName::new("Anthropic"),
2260 retry_after: Some(Duration::from_secs(3)),
2261 });
2262 fake_model.end_last_completion_stream();
2263
2264 cx.executor().advance_clock(Duration::from_secs(3));
2265 let completion = fake_model.pending_completions().pop().unwrap();
2266 assert_eq!(
2267 completion.messages[1..],
2268 vec![
2269 LanguageModelRequestMessage {
2270 role: Role::User,
2271 content: vec!["Call the echo tool!".into()],
2272 cache: false,
2273 reasoning_details: None,
2274 },
2275 LanguageModelRequestMessage {
2276 role: Role::Assistant,
2277 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
2278 cache: false,
2279 reasoning_details: None,
2280 },
2281 LanguageModelRequestMessage {
2282 role: Role::User,
2283 content: vec![language_model::MessageContent::ToolResult(
2284 LanguageModelToolResult {
2285 tool_use_id: tool_use_1.id.clone(),
2286 tool_name: tool_use_1.name.clone(),
2287 is_error: false,
2288 content: "test".into(),
2289 output: Some("test".into())
2290 }
2291 )],
2292 cache: true,
2293 reasoning_details: None,
2294 },
2295 ]
2296 );
2297
2298 fake_model.send_last_completion_stream_text_chunk("Done");
2299 fake_model.end_last_completion_stream();
2300 cx.run_until_parked();
2301 events.collect::<Vec<_>>().await;
2302 thread.read_with(cx, |thread, _cx| {
2303 assert_eq!(
2304 thread.last_message(),
2305 Some(Message::Agent(AgentMessage {
2306 content: vec![AgentMessageContent::Text("Done".into())],
2307 tool_results: IndexMap::default(),
2308 reasoning_details: None,
2309 }))
2310 );
2311 })
2312}
2313
2314#[gpui::test]
2315async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2316 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2317 let fake_model = model.as_fake();
2318
2319 let mut events = thread
2320 .update(cx, |thread, cx| {
2321 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2322 thread.send(UserMessageId::new(), ["Hello!"], cx)
2323 })
2324 .unwrap();
2325 cx.run_until_parked();
2326
2327 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2328 fake_model.send_last_completion_stream_error(
2329 LanguageModelCompletionError::ServerOverloaded {
2330 provider: LanguageModelProviderName::new("Anthropic"),
2331 retry_after: Some(Duration::from_secs(3)),
2332 },
2333 );
2334 fake_model.end_last_completion_stream();
2335 cx.executor().advance_clock(Duration::from_secs(3));
2336 cx.run_until_parked();
2337 }
2338
2339 let mut errors = Vec::new();
2340 let mut retry_events = Vec::new();
2341 while let Some(event) = events.next().await {
2342 match event {
2343 Ok(ThreadEvent::Retry(retry_status)) => {
2344 retry_events.push(retry_status);
2345 }
2346 Ok(ThreadEvent::Stop(..)) => break,
2347 Err(error) => errors.push(error),
2348 _ => {}
2349 }
2350 }
2351
2352 assert_eq!(
2353 retry_events.len(),
2354 crate::thread::MAX_RETRY_ATTEMPTS as usize
2355 );
2356 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2357 assert_eq!(retry_events[i].attempt, i + 1);
2358 }
2359 assert_eq!(errors.len(), 1);
2360 let error = errors[0]
2361 .downcast_ref::<LanguageModelCompletionError>()
2362 .unwrap();
2363 assert!(matches!(
2364 error,
2365 LanguageModelCompletionError::ServerOverloaded { .. }
2366 ));
2367}
2368
2369/// Filters out the stop events for asserting against in tests
2370fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2371 result_events
2372 .into_iter()
2373 .filter_map(|event| match event.unwrap() {
2374 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2375 _ => None,
2376 })
2377 .collect()
2378}
2379
/// Bundle of handles returned by [`setup`] that tests use to drive and
/// inspect a [`Thread`].
struct ThreadTest {
    /// The model wired into the thread (a `FakeLanguageModel` when created
    /// with `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context the thread was constructed with.
    project_context: Entity<ProjectContext>,
    /// Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing both the settings file and the test project.
    fs: Arc<FakeFs>,
}
2387
/// Which language model a test runs against: the real Sonnet 4 model
/// (requires network access and provider authentication) or an in-process
/// fake whose streams the test drives by hand.
enum TestModel {
    Sonnet4,
    Fake,
}
2392
2393impl TestModel {
2394 fn id(&self) -> LanguageModelId {
2395 match self {
2396 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2397 TestModel::Fake => unreachable!(),
2398 }
2399 }
2400}
2401
/// Builds the common test fixture: a fake filesystem with a settings file
/// selecting a profile that enables all test tools, a test project, the
/// requested model (fake or real Sonnet 4), and a fresh [`Thread`] wired to
/// all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file whose default profile enables every test tool.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real-model runs need an HTTP client, a production client,
                // and the language model provider stack initialized.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later edits to the
        // settings file made by individual tests.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is ready immediately; a real model is
    // looked up in the registry and its provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2503
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only initialize env_logger when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2511
2512fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2513 let fs = fs.clone();
2514 cx.spawn({
2515 async move |cx| {
2516 let mut new_settings_content_rx = settings::watch_config_file(
2517 cx.background_executor(),
2518 fs,
2519 paths::settings_file().clone(),
2520 );
2521
2522 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2523 cx.update(|cx| {
2524 SettingsStore::update_global(cx, |settings, cx| {
2525 settings.set_user_settings(&new_settings_content, cx)
2526 })
2527 })
2528 .ok();
2529 }
2530 }
2531 })
2532 .detach();
2533}
2534
2535fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2536 completion
2537 .tools
2538 .iter()
2539 .map(|tool| tool.name.clone())
2540 .collect()
2541}
2542
/// Registers a fake stdio MCP context server named `name` that advertises
/// `tools`, and returns a channel of incoming `CallTool` requests, each
/// paired with a sender the test uses to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it.
    // The command is never executed: the transport below is faked.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support (with list_changed notifications).
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Answer ListTools with the caller-provided tool definitions.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each CallTool request to the test over the channel and
        // block the RPC until the test sends back a response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server startup tasks run to completion before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}