1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the system prompt is built from the project context
    // (the configured shell shows up verbatim) and that registering a tool
    // pulls tool-related guidance into the prompt.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // With at least one tool registered, the "## Fixing Diagnostics"
    // section is expected to be present (contrast with
    // `test_system_prompt_without_tools`).
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // messages[0] is always the system message for a fresh thread.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the cache-marking strategy for completion requests: in every
    // request, only the newest message carries `cache: true` so the provider
    // can anchor its prompt cache at the longest shared prefix.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, hence the `[1..]` slice throughout.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request's final message is the tool
    // result, and that message becomes the new cache anchor.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // E2E: verifies tool calls resolve correctly whether the tool finishes
    // before or after the model stops streaming.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should echo the delay tool's output ("Ding") in its final
    // assistant message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
393
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // E2E: verifies that tool-call input streams incrementally. While the
    // call is still pending we should observe a tool use whose input is
    // incomplete; once the call resolves, all fields must have streamed in.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a turn that should invoke the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Input still streaming: count it as partial only if
                        // the last field ("g") hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
445
446#[gpui::test]
447async fn test_tool_authorization(cx: &mut TestAppContext) {
448 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 let mut events = thread
452 .update(cx, |thread, cx| {
453 thread.add_tool(ToolRequiringPermission);
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
459 LanguageModelToolUse {
460 id: "tool_id_1".into(),
461 name: ToolRequiringPermission::name().into(),
462 raw_input: "{}".into(),
463 input: json!({}),
464 is_input_complete: true,
465 thought_signature: None,
466 },
467 ));
468 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
469 LanguageModelToolUse {
470 id: "tool_id_2".into(),
471 name: ToolRequiringPermission::name().into(),
472 raw_input: "{}".into(),
473 input: json!({}),
474 is_input_complete: true,
475 thought_signature: None,
476 },
477 ));
478 fake_model.end_last_completion_stream();
479 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
480 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
481
482 // Approve the first
483 tool_call_auth_1
484 .response
485 .send(tool_call_auth_1.options[1].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488
489 // Reject the second
490 tool_call_auth_2
491 .response
492 .send(tool_call_auth_1.options[2].id.clone())
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![
501 language_model::MessageContent::ToolResult(LanguageModelToolResult {
502 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }),
508 language_model::MessageContent::ToolResult(LanguageModelToolResult {
509 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
510 tool_name: ToolRequiringPermission::name().into(),
511 is_error: true,
512 content: "Permission to run tool denied by user".into(),
513 output: Some("Permission to run tool denied by user".into())
514 })
515 ]
516 );
517
518 // Simulate yet another tool call.
519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
520 LanguageModelToolUse {
521 id: "tool_id_3".into(),
522 name: ToolRequiringPermission::name().into(),
523 raw_input: "{}".into(),
524 input: json!({}),
525 is_input_complete: true,
526 thought_signature: None,
527 },
528 ));
529 fake_model.end_last_completion_stream();
530
531 // Respond by always allowing tools.
532 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
533 tool_call_auth_3
534 .response
535 .send(tool_call_auth_3.options[0].id.clone())
536 .unwrap();
537 cx.run_until_parked();
538 let completion = fake_model.pending_completions().pop().unwrap();
539 let message = completion.messages.last().unwrap();
540 assert_eq!(
541 message.content,
542 vec![language_model::MessageContent::ToolResult(
543 LanguageModelToolResult {
544 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
545 tool_name: ToolRequiringPermission::name().into(),
546 is_error: false,
547 content: "Allowed".into(),
548 output: Some("Allowed".into())
549 }
550 )]
551 );
552
553 // Simulate a final tool call, ensuring we don't trigger authorization.
554 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
555 LanguageModelToolUse {
556 id: "tool_id_4".into(),
557 name: ToolRequiringPermission::name().into(),
558 raw_input: "{}".into(),
559 input: json!({}),
560 is_input_complete: true,
561 thought_signature: None,
562 },
563 ));
564 fake_model.end_last_completion_stream();
565 cx.run_until_parked();
566 let completion = fake_model.pending_completions().pop().unwrap();
567 let message = completion.messages.last().unwrap();
568 assert_eq!(
569 message.content,
570 vec![language_model::MessageContent::ToolResult(
571 LanguageModelToolResult {
572 tool_use_id: "tool_id_4".into(),
573 tool_name: ToolRequiringPermission::name().into(),
574 is_error: false,
575 content: "Allowed".into(),
576 output: Some("Allowed".into())
577 }
578 )]
579 );
580}
581
582#[gpui::test]
583async fn test_tool_hallucination(cx: &mut TestAppContext) {
584 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
585 let fake_model = model.as_fake();
586
587 let mut events = thread
588 .update(cx, |thread, cx| {
589 thread.send(UserMessageId::new(), ["abc"], cx)
590 })
591 .unwrap();
592 cx.run_until_parked();
593 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
594 LanguageModelToolUse {
595 id: "tool_id_1".into(),
596 name: "nonexistent_tool".into(),
597 raw_input: "{}".into(),
598 input: json!({}),
599 is_input_complete: true,
600 thought_signature: None,
601 },
602 ));
603 fake_model.end_last_completion_stream();
604
605 let tool_call = expect_tool_call(&mut events).await;
606 assert_eq!(tool_call.title, "nonexistent_tool");
607 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
608 let update = expect_tool_call_update_fields(&mut events).await;
609 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
610}
611
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports that the tool-use limit was reached, the turn
    // ends with `ToolUseLimitReachedError`; calling `resume` should retry by
    // appending a "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran, so the follow-up request must contain the tool result;
    // the newest message is the cache anchor (`cache: true`).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
721
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (instead of resuming) should still work: the prior tool use/result
    // pair stays in the conversation and the new message is appended.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately report the limit in the same stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh user message starts a new request containing the full history.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
796
797async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
798 let event = events
799 .next()
800 .await
801 .expect("no tool call authorization event received")
802 .unwrap();
803 match event {
804 ThreadEvent::ToolCall(tool_call) => tool_call,
805 event => {
806 panic!("Unexpected event {event:?}");
807 }
808 }
809}
810
811async fn expect_tool_call_update_fields(
812 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
813) -> acp::ToolCallUpdate {
814 let event = events
815 .next()
816 .await
817 .expect("no tool call authorization event received")
818 .unwrap();
819 match event {
820 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
821 event => {
822 panic!("Unexpected event {event:?}");
823 }
824 }
825}
826
827async fn next_tool_call_authorization(
828 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
829) -> ToolCallAuthorization {
830 loop {
831 let event = events
832 .next()
833 .await
834 .expect("no tool call authorization event received")
835 .unwrap();
836 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
837 let permission_kinds = tool_call_authorization
838 .options
839 .iter()
840 .map(|o| o.kind)
841 .collect::<Vec<_>>();
842 assert_eq!(
843 permission_kinds,
844 vec![
845 acp::PermissionOptionKind::AllowAlways,
846 acp::PermissionOptionKind::AllowOnce,
847 acp::PermissionOptionKind::RejectOnce,
848 ]
849 );
850 return tool_call_authorization;
851 }
852 }
853}
854
855#[gpui::test]
856#[cfg_attr(not(feature = "e2e"), ignore)]
857async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
858 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
859
860 // Test concurrent tool calls with different delay times
861 let events = thread
862 .update(cx, |thread, cx| {
863 thread.add_tool(DelayTool);
864 thread.send(
865 UserMessageId::new(),
866 [
867 "Call the delay tool twice in the same message.",
868 "Once with 100ms. Once with 300ms.",
869 "When both timers are complete, describe the outputs.",
870 ],
871 cx,
872 )
873 })
874 .unwrap()
875 .collect()
876 .await;
877
878 let stop_reasons = stop_events(events);
879 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
880
881 thread.update(cx, |thread, _cx| {
882 let last_message = thread.last_message().unwrap();
883 let agent_message = last_message.as_agent_message().unwrap();
884 let text = agent_message
885 .content
886 .iter()
887 .filter_map(|content| {
888 if let AgentMessageContent::Text(text) = content {
889 Some(text.as_str())
890 } else {
891 None
892 }
893 })
894 .collect::<String>();
895
896 assert!(text.contains("Ding"));
897 });
898}
899
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which registered tools
    // are offered to the model in a completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles decide which subset is exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
979
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies MCP (context server) tool integration: an MCP tool is offered
    // to the model, invoked through the context server, and — when a native
    // tool with the same name is later added — the MCP tool is renamed with a
    // server prefix ("test_server_echo") to resolve the collision.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing a tool named "echo" with the same
    // schema as the native EchoTool, to set up the later name collision.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call must be forwarded to the context server with its arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` popped before the
    // tool call above, not on a freshly popped follow-up request —
    // presumably the follow-up was intended; confirm against the harness.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Invoke both the renamed MCP tool and the native tool in one turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP side still sees the tool under its original, unprefixed name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1145
// Verifies how MCP (context-server) tool names are combined with native tool
// names in a completion request: names that collide with a native tool are
// disambiguated with a server-derived prefix, and names at or over
// MAX_TOOL_NAME_LENGTH are shortened so every entry fits within the limit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" adds another "echo" collision plus names sitting right at
    // the length limit (MAX_TOOL_NAME_LENGTH - 2 and - 1).
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz" duplicates the near-limit names from "yyy" (forcing
    // cross-server disambiguation) and adds one name over the limit.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected (sorted) names: native tools keep their names; MCP "echo"
    // collisions get server-id prefixes ("xxx_echo", "yyy_echo"); near-limit
    // names that collide across servers get a single-character prefix
    // ("y_…", "z_…"); the over-limit "c…" name is truncated to fit.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1311
// End-to-end test (only runs with the "e2e" feature): cancelling a turn while
// a tool is still executing must close the event stream with a `Cancelled`
// stop reason, and the thread must remain usable for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order we prompted for.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track when the echo tool (the one that terminates) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have started and echo has finished; the
        // infinite tool is deliberately still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1397
1398#[gpui::test]
1399async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1400 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1401 let fake_model = model.as_fake();
1402
1403 let events_1 = thread
1404 .update(cx, |thread, cx| {
1405 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1406 })
1407 .unwrap();
1408 cx.run_until_parked();
1409 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1410 cx.run_until_parked();
1411
1412 let events_2 = thread
1413 .update(cx, |thread, cx| {
1414 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1415 })
1416 .unwrap();
1417 cx.run_until_parked();
1418 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1419 fake_model
1420 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1421 fake_model.end_last_completion_stream();
1422
1423 let events_1 = events_1.collect::<Vec<_>>().await;
1424 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1425 let events_2 = events_2.collect::<Vec<_>>().await;
1426 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1427}
1428
1429#[gpui::test]
1430async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1431 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1432 let fake_model = model.as_fake();
1433
1434 let events_1 = thread
1435 .update(cx, |thread, cx| {
1436 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1437 })
1438 .unwrap();
1439 cx.run_until_parked();
1440 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1441 fake_model
1442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1443 fake_model.end_last_completion_stream();
1444 let events_1 = events_1.collect::<Vec<_>>().await;
1445
1446 let events_2 = thread
1447 .update(cx, |thread, cx| {
1448 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1449 })
1450 .unwrap();
1451 cx.run_until_parked();
1452 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1453 fake_model
1454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1455 fake_model.end_last_completion_stream();
1456 let events_2 = events_2.collect::<Vec<_>>().await;
1457
1458 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1459 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1460}
1461
1462#[gpui::test]
1463async fn test_refusal(cx: &mut TestAppContext) {
1464 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1465 let fake_model = model.as_fake();
1466
1467 let events = thread
1468 .update(cx, |thread, cx| {
1469 thread.send(UserMessageId::new(), ["Hello"], cx)
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473 thread.read_with(cx, |thread, _| {
1474 assert_eq!(
1475 thread.to_markdown(),
1476 indoc! {"
1477 ## User
1478
1479 Hello
1480 "}
1481 );
1482 });
1483
1484 fake_model.send_last_completion_stream_text_chunk("Hey!");
1485 cx.run_until_parked();
1486 thread.read_with(cx, |thread, _| {
1487 assert_eq!(
1488 thread.to_markdown(),
1489 indoc! {"
1490 ## User
1491
1492 Hello
1493
1494 ## Assistant
1495
1496 Hey!
1497 "}
1498 );
1499 });
1500
1501 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1502 fake_model
1503 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1504 let events = events.collect::<Vec<_>>().await;
1505 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1506 thread.read_with(cx, |thread, _| {
1507 assert_eq!(thread.to_markdown(), "");
1508 });
1509}
1510
// Truncating at the first (and only) user message should clear the thread
// entirely — both the markdown transcript and the tracked token usage — and
// the thread must accept and complete new messages afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has finished yet, so no usage is recorded.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update so `latest_token_usage` is set.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Used tokens are input + output; max_tokens presumably comes from
        // the fake model's configured context window — TODO confirm.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message resets the thread to empty.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation exchange.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1626
// Truncating at the second user message should restore the thread exactly to
// its state after the first exchange — transcript and token usage included.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Complete a first full exchange, including a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused before and after truncation: the thread
    // should look identical in both cases.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, keeping the message id so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message; the first exchange (and its usage)
    // must be restored exactly.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1735
1736#[gpui::test]
1737async fn test_title_generation(cx: &mut TestAppContext) {
1738 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1739 let fake_model = model.as_fake();
1740
1741 let summary_model = Arc::new(FakeLanguageModel::default());
1742 thread.update(cx, |thread, cx| {
1743 thread.set_summarization_model(Some(summary_model.clone()), cx)
1744 });
1745
1746 let send = thread
1747 .update(cx, |thread, cx| {
1748 thread.send(UserMessageId::new(), ["Hello"], cx)
1749 })
1750 .unwrap();
1751 cx.run_until_parked();
1752
1753 fake_model.send_last_completion_stream_text_chunk("Hey!");
1754 fake_model.end_last_completion_stream();
1755 cx.run_until_parked();
1756 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1757
1758 // Ensure the summary model has been invoked to generate a title.
1759 summary_model.send_last_completion_stream_text_chunk("Hello ");
1760 summary_model.send_last_completion_stream_text_chunk("world\nG");
1761 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1762 summary_model.end_last_completion_stream();
1763 send.collect::<Vec<_>>().await;
1764 cx.run_until_parked();
1765 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1766
1767 // Send another message, ensuring no title is generated this time.
1768 let send = thread
1769 .update(cx, |thread, cx| {
1770 thread.send(UserMessageId::new(), ["Hello again"], cx)
1771 })
1772 .unwrap();
1773 cx.run_until_parked();
1774 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1775 fake_model.end_last_completion_stream();
1776 cx.run_until_parked();
1777 assert_eq!(summary_model.pending_completions(), Vec::new());
1778 send.collect::<Vec<_>>().await;
1779 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1780}
1781
// A tool call that is still waiting on user permission must be excluded when
// building a completion request, while tool calls that already completed are
// included together with their results.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion and should appear in the request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped — presumably the system prompt; verify if the
    // surrounding harness changes. Only the echo tool use and its result are
    // present; the permission tool is absent entirely.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1858
// Integration test for `NativeAgentConnection`: builds the full agent stack
// (client, language-model registry, project), then exercises model listing
// and selection, prompting through an `AcpThread`, cancellation, and session
// cleanup once the thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream one assistant chunk back.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1996
// Streams a tool call whose input arrives in two stages (partial, then
// complete) and asserts the exact sequence of events the thread emits: an
// initial `ToolCall`, then field updates for the finalized input, the
// `InProgress` status, the streamed content, and finally `Completed` with
// the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The initial tool call, created from the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. The completed input replaces the partial one.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. The tool completes, carrying its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2105
2106#[gpui::test]
2107async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2108 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2109 let fake_model = model.as_fake();
2110
2111 let mut events = thread
2112 .update(cx, |thread, cx| {
2113 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2114 thread.send(UserMessageId::new(), ["Hello!"], cx)
2115 })
2116 .unwrap();
2117 cx.run_until_parked();
2118
2119 fake_model.send_last_completion_stream_text_chunk("Hey!");
2120 fake_model.end_last_completion_stream();
2121
2122 let mut retry_events = Vec::new();
2123 while let Some(Ok(event)) = events.next().await {
2124 match event {
2125 ThreadEvent::Retry(retry_status) => {
2126 retry_events.push(retry_status);
2127 }
2128 ThreadEvent::Stop(..) => break,
2129 _ => {}
2130 }
2131 }
2132
2133 assert_eq!(retry_events.len(), 0);
2134 thread.read_with(cx, |thread, _cx| {
2135 assert_eq!(
2136 thread.to_markdown(),
2137 indoc! {"
2138 ## User
2139
2140 Hello!
2141
2142 ## Assistant
2143
2144 Hey!
2145 "}
2146 )
2147 });
2148}
2149
// Verifies that a retryable provider error (ServerOverloaded) mid-stream makes
// the thread emit a Retry event, wait out `retry_after`, and resume the turn
// with a fresh completion request.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable error carrying a
    // 3-second retry hint.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advancing the clock past `retry_after` should start a new completion.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Finish the turn on the retried completion.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry was reported, as attempt number 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        // The transcript keeps the interrupted reply, a "[resume]" marker, and
        // the post-retry continuation as a separate assistant message.
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2214
// Verifies that when a retryable error interrupts a turn containing a tool
// call, the tool call still completes, and the retried completion request
// includes both the tool use and its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a complete tool use, then fail the stream with a retryable error
    // before the model produces any final text.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // The retried request must carry the tool use and the finished tool result
    // (messages[0] is the system prompt, hence the [1..] slice).
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and confirm the final assistant message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2291
// Verifies that retries stop after MAX_RETRY_ATTEMPTS: the thread emits one
// Retry event per attempt (with increasing attempt numbers) and then surfaces
// the provider error to the caller instead of retrying forever.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every completion: the initial one plus each of the retries.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as an error, not retried again.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2346
2347/// Filters out the stop events for asserting against in tests
2348fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2349 result_events
2350 .into_iter()
2351 .filter_map(|event| match event.unwrap() {
2352 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2353 _ => None,
2354 })
2355 .collect()
2356}
2357
/// Harness returned by `setup`: the model under test plus the entities a test
/// needs to drive and inspect a `Thread`.
struct ThreadTest {
    // Either a `FakeLanguageModel` or a real provider model (see `TestModel`).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread; tests may mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the project and the watched settings file.
    fs: Arc<FakeFs>,
}
2365
/// Which language model a test should run against.
enum TestModel {
    /// A real Anthropic model resolved from the registry; requires
    /// authentication, so only live/eval-style tests use it.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose stream the test drives manually.
    Fake,
}
2370
2371impl TestModel {
2372 fn id(&self) -> LanguageModelId {
2373 match self {
2374 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2375 TestModel::Fake => unreachable!(),
2376 }
2377 }
2378}
2379
/// Builds a fully wired `Thread` against a fake project for testing.
///
/// For `TestModel::Fake` this produces a `FakeLanguageModel` whose completion
/// streams the test drives by hand; for `TestModel::Sonnet4` it initializes
/// the real client stack and authenticates the provider.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed a settings file whose default profile enables every test tool, so
    // individual tests don't have to toggle tool availability themselves.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Live model: bring up the production HTTP client and the
                // language-model provider registry.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Mirror changes to the fake settings file into the SettingsStore,
        // like the real app's settings watcher does.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate its
                // provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2484
/// Runs before tests (via `ctor`) and enables `env_logger` output, but only
/// when the caller opted in by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2492
2493fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2494 let fs = fs.clone();
2495 cx.spawn({
2496 async move |cx| {
2497 let mut new_settings_content_rx = settings::watch_config_file(
2498 cx.background_executor(),
2499 fs,
2500 paths::settings_file().clone(),
2501 );
2502
2503 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2504 cx.update(|cx| {
2505 SettingsStore::update_global(cx, |settings, cx| {
2506 settings.set_user_settings(&new_settings_content, cx)
2507 })
2508 })
2509 .ok();
2510 }
2511 }
2512 })
2513 .detach();
2514}
2515
2516fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2517 completion
2518 .tools
2519 .iter()
2520 .map(|tool| tool.name.clone())
2521 .collect()
2522}
2523
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`, backed by an in-memory transport.
///
/// Returns a receiver yielding each incoming `CallTool` request paired with a
/// oneshot sender the test uses to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it; the
    // command is a dummy since the fake transport never spawns a process.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report tool support so the client will list tools.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Serve the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and block until it responds via
        // the paired oneshot sender.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the store finish the initialize/list-tools handshake before tests
    // start issuing tool calls.
    cx.run_until_parked();
    mcp_tool_calls_rx
}