1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a `Thinking` completion event is rendered as a
    // `<think>...</think>` block preceding the regular text in the
    // assistant's final message markdown.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking event first, then the visible answer text.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Checks that the first message of a completion request is a system
    // prompt, that it incorporates the project context (here, the shell
    // name), and that it includes the diagnostics-fixing instructions.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable value into the project context so we can assert
    // it flows into the generated system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: on every request, only the most
    // recent message carries `cache: true`; all earlier history is sent
    // uncached. The `messages[1..]` slices below skip the system prompt.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool ran and its result triggers a follow-up request; the tool
    // result message should now be the only cached entry.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
296
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature) against a real
    // Sonnet 4 model, covering tool calls that finish both before and after
    // the model's streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's output ("Ding") should appear somewhere in the
    // assistant's final message text.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
356
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): while tool-call
    // input is still streaming, the thread's last message should expose a
    // partially-populated tool use; once complete, all fields are present.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // the last key ("g") hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
408
409#[gpui::test]
410async fn test_tool_authorization(cx: &mut TestAppContext) {
411 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
412 let fake_model = model.as_fake();
413
414 let mut events = thread
415 .update(cx, |thread, cx| {
416 thread.add_tool(ToolRequiringPermission);
417 thread.send(UserMessageId::new(), ["abc"], cx)
418 })
419 .unwrap();
420 cx.run_until_parked();
421 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
422 LanguageModelToolUse {
423 id: "tool_id_1".into(),
424 name: ToolRequiringPermission::name().into(),
425 raw_input: "{}".into(),
426 input: json!({}),
427 is_input_complete: true,
428 },
429 ));
430 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
431 LanguageModelToolUse {
432 id: "tool_id_2".into(),
433 name: ToolRequiringPermission::name().into(),
434 raw_input: "{}".into(),
435 input: json!({}),
436 is_input_complete: true,
437 },
438 ));
439 fake_model.end_last_completion_stream();
440 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
441 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
442
443 // Approve the first
444 tool_call_auth_1
445 .response
446 .send(tool_call_auth_1.options[1].id.clone())
447 .unwrap();
448 cx.run_until_parked();
449
450 // Reject the second
451 tool_call_auth_2
452 .response
453 .send(tool_call_auth_1.options[2].id.clone())
454 .unwrap();
455 cx.run_until_parked();
456
457 let completion = fake_model.pending_completions().pop().unwrap();
458 let message = completion.messages.last().unwrap();
459 assert_eq!(
460 message.content,
461 vec![
462 language_model::MessageContent::ToolResult(LanguageModelToolResult {
463 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
464 tool_name: ToolRequiringPermission::name().into(),
465 is_error: false,
466 content: "Allowed".into(),
467 output: Some("Allowed".into())
468 }),
469 language_model::MessageContent::ToolResult(LanguageModelToolResult {
470 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
471 tool_name: ToolRequiringPermission::name().into(),
472 is_error: true,
473 content: "Permission to run tool denied by user".into(),
474 output: Some("Permission to run tool denied by user".into())
475 })
476 ]
477 );
478
479 // Simulate yet another tool call.
480 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
481 LanguageModelToolUse {
482 id: "tool_id_3".into(),
483 name: ToolRequiringPermission::name().into(),
484 raw_input: "{}".into(),
485 input: json!({}),
486 is_input_complete: true,
487 },
488 ));
489 fake_model.end_last_completion_stream();
490
491 // Respond by always allowing tools.
492 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
493 tool_call_auth_3
494 .response
495 .send(tool_call_auth_3.options[0].id.clone())
496 .unwrap();
497 cx.run_until_parked();
498 let completion = fake_model.pending_completions().pop().unwrap();
499 let message = completion.messages.last().unwrap();
500 assert_eq!(
501 message.content,
502 vec![language_model::MessageContent::ToolResult(
503 LanguageModelToolResult {
504 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
505 tool_name: ToolRequiringPermission::name().into(),
506 is_error: false,
507 content: "Allowed".into(),
508 output: Some("Allowed".into())
509 }
510 )]
511 );
512
513 // Simulate a final tool call, ensuring we don't trigger authorization.
514 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
515 LanguageModelToolUse {
516 id: "tool_id_4".into(),
517 name: ToolRequiringPermission::name().into(),
518 raw_input: "{}".into(),
519 input: json!({}),
520 is_input_complete: true,
521 },
522 ));
523 fake_model.end_last_completion_stream();
524 cx.run_until_parked();
525 let completion = fake_model.pending_completions().pop().unwrap();
526 let message = completion.messages.last().unwrap();
527 assert_eq!(
528 message.content,
529 vec![language_model::MessageContent::ToolResult(
530 LanguageModelToolResult {
531 tool_use_id: "tool_id_4".into(),
532 tool_name: ToolRequiringPermission::name().into(),
533 is_error: false,
534 content: "Allowed".into(),
535 output: Some("Allowed".into())
536 }
537 )]
538 );
539}
540
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    // When the model calls a tool that was never registered, the thread
    // should surface a tool call that immediately transitions to Failed
    // rather than panicking or hanging.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // The unknown tool name is used verbatim as the call's title.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
569
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports `ToolUseLimitReached`, the send stream ends
    // with that error; calling `resume` should re-send the full history plus
    // a synthetic "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran; the follow-up request carries the tool result as the
    // latest (and therefore cached) message. `messages[1..]` skips the
    // system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a "Continue where you left off" user message, which
    // becomes the new cached tail of the request.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
678
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // If the tool-use limit is hit, sending a brand-new user message (rather
    // than resuming) should still include the prior tool use/result history
    // in the next request, with only the new message cached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use immediately followed by the limit-reached status.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // `messages[1..]` skips the system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
752
753async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
754 let event = events
755 .next()
756 .await
757 .expect("no tool call authorization event received")
758 .unwrap();
759 match event {
760 ThreadEvent::ToolCall(tool_call) => tool_call,
761 event => {
762 panic!("Unexpected event {event:?}");
763 }
764 }
765}
766
767async fn expect_tool_call_update_fields(
768 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
769) -> acp::ToolCallUpdate {
770 let event = events
771 .next()
772 .await
773 .expect("no tool call authorization event received")
774 .unwrap();
775 match event {
776 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
777 event => {
778 panic!("Unexpected event {event:?}");
779 }
780 }
781}
782
783async fn next_tool_call_authorization(
784 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
785) -> ToolCallAuthorization {
786 loop {
787 let event = events
788 .next()
789 .await
790 .expect("no tool call authorization event received")
791 .unwrap();
792 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
793 let permission_kinds = tool_call_authorization
794 .options
795 .iter()
796 .map(|o| o.kind)
797 .collect::<Vec<_>>();
798 assert_eq!(
799 permission_kinds,
800 vec![
801 acp::PermissionOptionKind::AllowAlways,
802 acp::PermissionOptionKind::AllowOnce,
803 acp::PermissionOptionKind::RejectOnce,
804 ]
805 );
806 return tool_call_authorization;
807 }
808 }
809}
810
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): two delay-tool
    // calls issued in the same turn should both complete, and their output
    // should be reflected in the final assistant message.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // Concatenate all text content of the final message and check for the
    // delay tool's "Ding" output.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
855
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles control which registered tools are
    // advertised to the model: switching profiles mid-thread changes the
    // tool list sent with the next completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; profiles filter them below.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
935
936#[gpui::test]
937async fn test_mcp_tools(cx: &mut TestAppContext) {
938 let ThreadTest {
939 model,
940 thread,
941 context_server_store,
942 fs,
943 ..
944 } = setup(cx, TestModel::Fake).await;
945 let fake_model = model.as_fake();
946
947 // Override profiles and wait for settings to be loaded.
948 fs.insert_file(
949 paths::settings_file(),
950 json!({
951 "agent": {
952 "always_allow_tool_actions": true,
953 "profiles": {
954 "test": {
955 "name": "Test Profile",
956 "enable_all_context_servers": true,
957 "tools": {
958 EchoTool::name(): true,
959 }
960 },
961 }
962 }
963 })
964 .to_string()
965 .into_bytes(),
966 )
967 .await;
968 cx.run_until_parked();
969 thread.update(cx, |thread, _| {
970 thread.set_profile(AgentProfileId("test".into()))
971 });
972
973 let mut mcp_tool_calls = setup_context_server(
974 "test_server",
975 vec![context_server::types::Tool {
976 name: "echo".into(),
977 description: None,
978 input_schema: serde_json::to_value(EchoTool::input_schema(
979 LanguageModelToolSchemaFormat::JsonSchema,
980 ))
981 .unwrap(),
982 output_schema: None,
983 annotations: None,
984 }],
985 &context_server_store,
986 cx,
987 );
988
989 let events = thread.update(cx, |thread, cx| {
990 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
991 });
992 cx.run_until_parked();
993
994 // Simulate the model calling the MCP tool.
995 let completion = fake_model.pending_completions().pop().unwrap();
996 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
997 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
998 LanguageModelToolUse {
999 id: "tool_1".into(),
1000 name: "echo".into(),
1001 raw_input: json!({"text": "test"}).to_string(),
1002 input: json!({"text": "test"}),
1003 is_input_complete: true,
1004 },
1005 ));
1006 fake_model.end_last_completion_stream();
1007 cx.run_until_parked();
1008
1009 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1010 assert_eq!(tool_call_params.name, "echo");
1011 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1012 tool_call_response
1013 .send(context_server::types::CallToolResponse {
1014 content: vec![context_server::types::ToolResponseContent::Text {
1015 text: "test".into(),
1016 }],
1017 is_error: None,
1018 meta: None,
1019 structured_content: None,
1020 })
1021 .unwrap();
1022 cx.run_until_parked();
1023
1024 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1025 fake_model.send_last_completion_stream_text_chunk("Done!");
1026 fake_model.end_last_completion_stream();
1027 events.collect::<Vec<_>>().await;
1028
1029 // Send again after adding the echo tool, ensuring the name collision is resolved.
1030 let events = thread.update(cx, |thread, cx| {
1031 thread.add_tool(EchoTool);
1032 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1033 });
1034 cx.run_until_parked();
1035 let completion = fake_model.pending_completions().pop().unwrap();
1036 assert_eq!(
1037 tool_names_for_completion(&completion),
1038 vec!["echo", "test_server_echo"]
1039 );
1040 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1041 LanguageModelToolUse {
1042 id: "tool_2".into(),
1043 name: "test_server_echo".into(),
1044 raw_input: json!({"text": "mcp"}).to_string(),
1045 input: json!({"text": "mcp"}),
1046 is_input_complete: true,
1047 },
1048 ));
1049 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1050 LanguageModelToolUse {
1051 id: "tool_3".into(),
1052 name: "echo".into(),
1053 raw_input: json!({"text": "native"}).to_string(),
1054 input: json!({"text": "native"}),
1055 is_input_complete: true,
1056 },
1057 ));
1058 fake_model.end_last_completion_stream();
1059 cx.run_until_parked();
1060
1061 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1062 assert_eq!(tool_call_params.name, "echo");
1063 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1064 tool_call_response
1065 .send(context_server::types::CallToolResponse {
1066 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1067 is_error: None,
1068 meta: None,
1069 structured_content: None,
1070 })
1071 .unwrap();
1072 cx.run_until_parked();
1073
1074 // Ensure the tool results were inserted with the correct names.
1075 let completion = fake_model.pending_completions().pop().unwrap();
1076 assert_eq!(
1077 completion.messages.last().unwrap().content,
1078 vec![
1079 MessageContent::ToolResult(LanguageModelToolResult {
1080 tool_use_id: "tool_3".into(),
1081 tool_name: "echo".into(),
1082 is_error: false,
1083 content: "native".into(),
1084 output: Some("native".into()),
1085 },),
1086 MessageContent::ToolResult(LanguageModelToolResult {
1087 tool_use_id: "tool_2".into(),
1088 tool_name: "test_server_echo".into(),
1089 is_error: false,
1090 content: "mcp".into(),
1091 output: Some("mcp".into()),
1092 },),
1093 ]
1094 );
1095 fake_model.end_last_completion_stream();
1096 events.collect::<Vec<_>>().await;
1097}
1098
// Verifies how tool names from context servers (MCP) are reconciled when
// building a completion request alongside native tools:
// - a server tool colliding with a native tool name gets the server name
//   prefixed ("xxx_echo", "yyy_echo");
// - tools with the same long name on different servers are disambiguated
//   with a shortened prefix ("y_aaa…", "z_aaa…") so the result still fits
//   within MAX_TOOL_NAME_LENGTH — presumably 64, judging by the expected
//   strings below; TODO confirm against the constant's definition.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names near the limit; duplicated on server "zzz" below to
            // exercise cross-server disambiguation of near-max-length names.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character over the limit on its own.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request must contain the fully reconciled tool list.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1264
// End-to-end test (only runs with the "e2e" feature): cancelling a turn while
// one of its tools is still running must close the event stream with
// StopReason::Cancelled, and the thread must remain usable for a new send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool's completion; the infinite tool never
            // completes on its own.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1350
1351#[gpui::test]
1352async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1353 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1354 let fake_model = model.as_fake();
1355
1356 let events_1 = thread
1357 .update(cx, |thread, cx| {
1358 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1359 })
1360 .unwrap();
1361 cx.run_until_parked();
1362 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1363 cx.run_until_parked();
1364
1365 let events_2 = thread
1366 .update(cx, |thread, cx| {
1367 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1368 })
1369 .unwrap();
1370 cx.run_until_parked();
1371 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1372 fake_model
1373 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1374 fake_model.end_last_completion_stream();
1375
1376 let events_1 = events_1.collect::<Vec<_>>().await;
1377 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1378 let events_2 = events_2.collect::<Vec<_>>().await;
1379 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1380}
1381
1382#[gpui::test]
1383async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1384 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1385 let fake_model = model.as_fake();
1386
1387 let events_1 = thread
1388 .update(cx, |thread, cx| {
1389 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1390 })
1391 .unwrap();
1392 cx.run_until_parked();
1393 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1394 fake_model
1395 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1396 fake_model.end_last_completion_stream();
1397 let events_1 = events_1.collect::<Vec<_>>().await;
1398
1399 let events_2 = thread
1400 .update(cx, |thread, cx| {
1401 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1402 })
1403 .unwrap();
1404 cx.run_until_parked();
1405 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1406 fake_model
1407 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1408 fake_model.end_last_completion_stream();
1409 let events_2 = events_2.collect::<Vec<_>>().await;
1410
1411 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1412 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1413}
1414
// When the model stops with StopReason::Refusal, the turn must surface
// acp::StopReason::Refusal and the thread must be rolled back — the final
// assertion shows the markdown becomes empty, i.e. the user message and the
// partial assistant reply are both removed.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // A partial assistant reply is now present.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1463
// Truncating at the first (and only) user message must clear the whole thread
// and reset the reported token usage, and the thread must accept and track a
// fresh message afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at exactly this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported before the model responds.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the UsageUpdate event.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message empties the thread and resets usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from the new turn replaces the pre-truncation usage.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1579
// Truncating at the second user message must drop that message and its
// response while restoring the thread (content and token usage) to the state
// it had after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message 1 plus a streamed response and usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion for the state after the first exchange; reused after
    // truncation to prove the thread was restored exactly.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose id we keep so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2: everything from it onward should disappear.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1688
// The summarization model must be asked to generate a title exactly once —
// after the first exchange — and only the first line of its output is used
// ("Hello world", not the text after the newline). Subsequent messages must
// not regenerate the title.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // A separate fake model handles summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line of the streamed summary becomes the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new completion was requested from the summary model.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1734
// When a request is built while some tool calls are still awaiting user
// permission, those pending tool uses must be omitted from the request, while
// completed tool uses (and their results) are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool stays pending: it waits for permission that is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool runs to completion and should appear in the request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system/context message; compare only the rest.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1809
// Integration test for the ACP connection layer: thread creation, model
// listing/selection, prompting, cancellation, and session teardown when the
// ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: the test registry exposes the fake provider's model.
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a reply from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1946
// A tool use streamed in two parts (partial input, then complete input) must
// produce the full ACP event sequence: ToolCall (Pending) → raw-input update →
// InProgress → content update → Completed with the tool's raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The initial ToolCall event carries the partial (empty) raw input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // Completed input triggers an update with the final raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // The tool starts running...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // ...streams its content...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // ...and finishes with its raw output attached.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2053
2054#[gpui::test]
2055async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2056 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2057 let fake_model = model.as_fake();
2058
2059 let mut events = thread
2060 .update(cx, |thread, cx| {
2061 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2062 thread.send(UserMessageId::new(), ["Hello!"], cx)
2063 })
2064 .unwrap();
2065 cx.run_until_parked();
2066
2067 fake_model.send_last_completion_stream_text_chunk("Hey!");
2068 fake_model.end_last_completion_stream();
2069
2070 let mut retry_events = Vec::new();
2071 while let Some(Ok(event)) = events.next().await {
2072 match event {
2073 ThreadEvent::Retry(retry_status) => {
2074 retry_events.push(retry_status);
2075 }
2076 ThreadEvent::Stop(..) => break,
2077 _ => {}
2078 }
2079 }
2080
2081 assert_eq!(retry_events.len(), 0);
2082 thread.read_with(cx, |thread, _cx| {
2083 assert_eq!(
2084 thread.to_markdown(),
2085 indoc! {"
2086 ## User
2087
2088 Hello!
2089
2090 ## Assistant
2091
2092 Hey!
2093 "}
2094 )
2095 });
2096}
2097
// A ServerOverloaded error mid-stream (in Burn mode) must trigger exactly one
// retry after the provider's retry_after delay, resuming the turn; the thread
// records the resumption as a separate assistant message.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with a retryable provider error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past retry_after so the scheduled retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2162
// If a retryable error arrives after a tool use was streamed, the tool call
// must still be resolved before the retry: the retried request contains the
// tool use and its result, and the turn then finishes normally.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a tool use, then fail the completion with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must include both the tool use
    // and its completed result (messages[0] is the system/context message).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // The retried turn completes with a plain text reply.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2238
2239#[gpui::test]
2240async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2241 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2242 let fake_model = model.as_fake();
2243
2244 let mut events = thread
2245 .update(cx, |thread, cx| {
2246 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2247 thread.send(UserMessageId::new(), ["Hello!"], cx)
2248 })
2249 .unwrap();
2250 cx.run_until_parked();
2251
2252 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2253 fake_model.send_last_completion_stream_error(
2254 LanguageModelCompletionError::ServerOverloaded {
2255 provider: LanguageModelProviderName::new("Anthropic"),
2256 retry_after: Some(Duration::from_secs(3)),
2257 },
2258 );
2259 fake_model.end_last_completion_stream();
2260 cx.executor().advance_clock(Duration::from_secs(3));
2261 cx.run_until_parked();
2262 }
2263
2264 let mut errors = Vec::new();
2265 let mut retry_events = Vec::new();
2266 while let Some(event) = events.next().await {
2267 match event {
2268 Ok(ThreadEvent::Retry(retry_status)) => {
2269 retry_events.push(retry_status);
2270 }
2271 Ok(ThreadEvent::Stop(..)) => break,
2272 Err(error) => errors.push(error),
2273 _ => {}
2274 }
2275 }
2276
2277 assert_eq!(
2278 retry_events.len(),
2279 crate::thread::MAX_RETRY_ATTEMPTS as usize
2280 );
2281 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2282 assert_eq!(retry_events[i].attempt, i + 1);
2283 }
2284 assert_eq!(errors.len(), 1);
2285 let error = errors[0]
2286 .downcast_ref::<LanguageModelCompletionError>()
2287 .unwrap();
2288 assert!(matches!(
2289 error,
2290 LanguageModelCompletionError::ServerOverloaded { .. }
2291 ));
2292}
2293
2294/// Filters out the stop events for asserting against in tests
2295fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2296 result_events
2297 .into_iter()
2298 .filter_map(|event| match event.unwrap() {
2299 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2300 _ => None,
2301 })
2302 .collect()
2303}
2304
/// Everything a test needs after [`setup`]: the model under test, the thread
/// entity, and the supporting stores/filesystem for further manipulation.
struct ThreadTest {
    // The language model the thread was created with (fake for most tests).
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    // Shared project context the thread renders its system prompt from.
    project_context: Entity<ProjectContext>,
    // Used by tests that register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2312
/// Which model backend a test runs against: a real Claude Sonnet 4 model
/// (requires credentials) or the in-process fake model.
enum TestModel {
    Sonnet4,
    Fake,
}
2317
2318impl TestModel {
2319 fn id(&self) -> LanguageModelId {
2320 match self {
2321 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2322 TestModel::Fake => unreachable!(),
2323 }
2324 }
2325}
2326
/// Builds a fully-wired test fixture: fake filesystem with a settings file
/// enabling all test tools, a test project, the requested model (fake or a
/// real authenticated Sonnet 4), and a `Thread` connected to all of it.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file whose profile enables every test tool, so the
    // thread's active profile doesn't filter them out.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real-model runs need HTTP, a production client, and the
            // language-model providers registered.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Look the model up in the registry and authenticate its
                // provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2431
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up env_logger when the user opted in via RUST_LOG.
    let Ok(_) = std::env::var("RUST_LOG") else {
        return;
    };
    env_logger::init();
}
2439
2440fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2441 let fs = fs.clone();
2442 cx.spawn({
2443 async move |cx| {
2444 let mut new_settings_content_rx = settings::watch_config_file(
2445 cx.background_executor(),
2446 fs,
2447 paths::settings_file().clone(),
2448 );
2449
2450 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2451 cx.update(|cx| {
2452 SettingsStore::update_global(cx, |settings, cx| {
2453 settings.set_user_settings(&new_settings_content, cx)
2454 })
2455 })
2456 .ok();
2457 }
2458 }
2459 })
2460 .detach();
2461}
2462
2463fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2464 completion
2465 .tools
2466 .iter()
2467 .map(|tool| tool.name.clone())
2468 .collect()
2469}
2470
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`. Returns a receiver that yields one `(params, responder)` pair per
/// tool call; the test must send a response through the `oneshot` sender to
/// unblock the server.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is never executed; the fake transport below intercepts I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support (with list_changed notifications).
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Serve the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and await its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool registration complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}