1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, new_prompt_id};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(new_prompt_id(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 new_prompt_id(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies the first message of every request is a system prompt rendered
    // from the project context and the enabled tools.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable shell name into the project context so we can
    // detect it inside the rendered system prompt below.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The request must start with a System-role message.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // The prompt should reflect the project context (shell) we set above...
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // ...and include the diagnostics section of the prompt template.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
160
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies that exactly one message — the most recent one in the request —
    // carries `cache: true`, and that the cache point advances as the
    // conversation grows (including past tool results).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; only the conversation tail is compared.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool ran (EchoTool echoes its input), producing a follow-up request
    // whose final message is the tool result — which must be the cached one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
294
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (Sonnet 4); only runs with the
    // "e2e" feature enabled.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                new_prompt_id(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                new_prompt_id(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The model should have echoed DelayTool's output ("Ding") in its
        // final message; dump the whole thread on failure for debugging.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
354
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real model): verifies that tool-use input streams into
    // the thread incrementally, i.e. we observe a tool use before its input
    // is complete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(new_prompt_id(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be streaming:
                        // a missing "g" key with incomplete input counts as
                        // observing a partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
406
407#[gpui::test]
408async fn test_tool_authorization(cx: &mut TestAppContext) {
409 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
410 let fake_model = model.as_fake();
411
412 let mut events = thread
413 .update(cx, |thread, cx| {
414 thread.add_tool(ToolRequiringPermission);
415 thread.send(new_prompt_id(), ["abc"], cx)
416 })
417 .unwrap();
418 cx.run_until_parked();
419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
420 LanguageModelToolUse {
421 id: "tool_id_1".into(),
422 name: ToolRequiringPermission::name().into(),
423 raw_input: "{}".into(),
424 input: json!({}),
425 is_input_complete: true,
426 },
427 ));
428 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
429 LanguageModelToolUse {
430 id: "tool_id_2".into(),
431 name: ToolRequiringPermission::name().into(),
432 raw_input: "{}".into(),
433 input: json!({}),
434 is_input_complete: true,
435 },
436 ));
437 fake_model.end_last_completion_stream();
438 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
439 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
440
441 // Approve the first
442 tool_call_auth_1
443 .response
444 .send(tool_call_auth_1.options[1].id.clone())
445 .unwrap();
446 cx.run_until_parked();
447
448 // Reject the second
449 tool_call_auth_2
450 .response
451 .send(tool_call_auth_1.options[2].id.clone())
452 .unwrap();
453 cx.run_until_parked();
454
455 let completion = fake_model.pending_completions().pop().unwrap();
456 let message = completion.messages.last().unwrap();
457 assert_eq!(
458 message.content,
459 vec![
460 language_model::MessageContent::ToolResult(LanguageModelToolResult {
461 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
462 tool_name: ToolRequiringPermission::name().into(),
463 is_error: false,
464 content: "Allowed".into(),
465 output: Some("Allowed".into())
466 }),
467 language_model::MessageContent::ToolResult(LanguageModelToolResult {
468 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
469 tool_name: ToolRequiringPermission::name().into(),
470 is_error: true,
471 content: "Permission to run tool denied by user".into(),
472 output: None
473 })
474 ]
475 );
476
477 // Simulate yet another tool call.
478 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
479 LanguageModelToolUse {
480 id: "tool_id_3".into(),
481 name: ToolRequiringPermission::name().into(),
482 raw_input: "{}".into(),
483 input: json!({}),
484 is_input_complete: true,
485 },
486 ));
487 fake_model.end_last_completion_stream();
488
489 // Respond by always allowing tools.
490 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
491 tool_call_auth_3
492 .response
493 .send(tool_call_auth_3.options[0].id.clone())
494 .unwrap();
495 cx.run_until_parked();
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![language_model::MessageContent::ToolResult(
501 LanguageModelToolResult {
502 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }
508 )]
509 );
510
511 // Simulate a final tool call, ensuring we don't trigger authorization.
512 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
513 LanguageModelToolUse {
514 id: "tool_id_4".into(),
515 name: ToolRequiringPermission::name().into(),
516 raw_input: "{}".into(),
517 input: json!({}),
518 is_input_complete: true,
519 },
520 ));
521 fake_model.end_last_completion_stream();
522 cx.run_until_parked();
523 let completion = fake_model.pending_completions().pop().unwrap();
524 let message = completion.messages.last().unwrap();
525 assert_eq!(
526 message.content,
527 vec![language_model::MessageContent::ToolResult(
528 LanguageModelToolResult {
529 tool_use_id: "tool_id_4".into(),
530 tool_name: ToolRequiringPermission::name().into(),
531 is_error: false,
532 content: "Allowed".into(),
533 output: Some("Allowed".into())
534 }
535 )]
536 );
537}
538
539#[gpui::test]
540async fn test_tool_hallucination(cx: &mut TestAppContext) {
541 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
542 let fake_model = model.as_fake();
543
544 let mut events = thread
545 .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
546 .unwrap();
547 cx.run_until_parked();
548 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
549 LanguageModelToolUse {
550 id: "tool_id_1".into(),
551 name: "nonexistent_tool".into(),
552 raw_input: "{}".into(),
553 input: json!({}),
554 is_input_complete: true,
555 },
556 ));
557 fake_model.end_last_completion_stream();
558
559 let tool_call = expect_tool_call(&mut events).await;
560 assert_eq!(tool_call.title, "nonexistent_tool");
561 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
562 let update = expect_tool_call_update_fields(&mut events).await;
563 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
564}
565
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // Verifies that after hitting the tool-use limit, `resume` appends a
    // "Continue where you left off" user message and replays the history.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The echo tool ran; the follow-up request ends with its (cached) result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic continuation message; the cache point
    // moves onto it (the tool result is no longer cached).
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
674
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // Verifies that sending a *new* user message after hitting the tool-use
    // limit works without an explicit resume: the history (including the
    // synthesized tool result) is replayed with the new message appended.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Hit the limit in the same turn as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream ends with the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Only the new user message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
748
749async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
750 let event = events
751 .next()
752 .await
753 .expect("no tool call authorization event received")
754 .unwrap();
755 match event {
756 ThreadEvent::ToolCall(tool_call) => tool_call,
757 event => {
758 panic!("Unexpected event {event:?}");
759 }
760 }
761}
762
763async fn expect_tool_call_update_fields(
764 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
765) -> acp::ToolCallUpdate {
766 let event = events
767 .next()
768 .await
769 .expect("no tool call authorization event received")
770 .unwrap();
771 match event {
772 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
773 event => {
774 panic!("Unexpected event {event:?}");
775 }
776 }
777}
778
779async fn next_tool_call_authorization(
780 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
781) -> ToolCallAuthorization {
782 loop {
783 let event = events
784 .next()
785 .await
786 .expect("no tool call authorization event received")
787 .unwrap();
788 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
789 let permission_kinds = tool_call_authorization
790 .options
791 .iter()
792 .map(|o| o.kind)
793 .collect::<Vec<_>>();
794 assert_eq!(
795 permission_kinds,
796 vec![
797 acp::PermissionOptionKind::AllowAlways,
798 acp::PermissionOptionKind::AllowOnce,
799 acp::PermissionOptionKind::RejectOnce,
800 ]
801 );
802 return tool_call_authorization;
803 }
804 }
805}
806
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real model): two delay-tool calls in one message must
    // both complete before the turn ends cleanly.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                new_prompt_id(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // that the model surfaced DelayTool's "Ding" output.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
851
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which of the registered
    // tools are actually offered to the model.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles below enable different subsets.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(new_prompt_id(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(new_prompt_id(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
931
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies MCP (context-server) tools are exposed to the model, that
    // their calls round-trip through the server, and that a name collision
    // with a native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Fake context server exposing a single MCP tool named "echo".
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(new_prompt_id(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call must reach the fake server with the model-provided arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts the `completion` captured above rather
    // than popping the new pending completion — likely intended to check the
    // follow-up request's tools; verify.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(new_prompt_id(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    // The native tool keeps "echo"; the MCP tool is renamed "test_server_echo".
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The server still receives its tool under the unprefixed name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1093
// Exercises tool-name disambiguation when native tools and multiple MCP
// (context server) tools collide, or exceed MAX_TOOL_NAME_LENGTH: conflicting
// names get a server-id-derived prefix, and over-long names are truncated so
// everything offered to the model fits within the limit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit name, also registered by server "zzz" below, forcing
            // prefixed disambiguation that must itself be truncated to fit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Exceeds the limit outright, so it must be truncated even without
            // any collision.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Go"], cx))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected names (sorted): native tools keep their names; the MCP "echo"
    // tools are disambiguated with their server id ("xxx_echo"/"yyy_echo");
    // colliding near-limit names get a single-letter server prefix ("y_…",
    // "z_…") so the result still fits MAX_TOOL_NAME_LENGTH.
    // NOTE(review): the 63-char "b…" tool is registered by both "yyy" and
    // "zzz" but appears only once here — presumably deduplicated after
    // truncation; confirm against the truncation logic.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1257
// End-to-end test (feature = "e2e", real Sonnet 4 model): cancelling a turn
// closes the event stream with StopReason::Cancelled even while a tool is
// still running, and the thread remains usable for subsequent sends.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                new_prompt_id(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls are expected in the order the prompt requested them.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo tool (the only one that finishes;
            // the infinite tool runs forever by design).
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Proceed once both tools have started and echo has completed; the
        // infinite tool is intentionally still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                new_prompt_id(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1342
1343#[gpui::test]
1344async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1345 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1346 let fake_model = model.as_fake();
1347
1348 let events_1 = thread
1349 .update(cx, |thread, cx| {
1350 thread.send(new_prompt_id(), ["Hello 1"], cx)
1351 })
1352 .unwrap();
1353 cx.run_until_parked();
1354 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1355 cx.run_until_parked();
1356
1357 let events_2 = thread
1358 .update(cx, |thread, cx| {
1359 thread.send(new_prompt_id(), ["Hello 2"], cx)
1360 })
1361 .unwrap();
1362 cx.run_until_parked();
1363 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1364 fake_model
1365 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1366 fake_model.end_last_completion_stream();
1367
1368 let events_1 = events_1.collect::<Vec<_>>().await;
1369 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1370 let events_2 = events_2.collect::<Vec<_>>().await;
1371 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1372}
1373
1374#[gpui::test]
1375async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1376 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1377 let fake_model = model.as_fake();
1378
1379 let events_1 = thread
1380 .update(cx, |thread, cx| {
1381 thread.send(new_prompt_id(), ["Hello 1"], cx)
1382 })
1383 .unwrap();
1384 cx.run_until_parked();
1385 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1386 fake_model
1387 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1388 fake_model.end_last_completion_stream();
1389 let events_1 = events_1.collect::<Vec<_>>().await;
1390
1391 let events_2 = thread
1392 .update(cx, |thread, cx| {
1393 thread.send(new_prompt_id(), ["Hello 2"], cx)
1394 })
1395 .unwrap();
1396 cx.run_until_parked();
1397 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1398 fake_model
1399 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1400 fake_model.end_last_completion_stream();
1401 let events_2 = events_2.collect::<Vec<_>>().await;
1402
1403 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1404 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1405}
1406
1407#[gpui::test]
1408async fn test_refusal(cx: &mut TestAppContext) {
1409 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1410 let fake_model = model.as_fake();
1411
1412 let events = thread
1413 .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
1414 .unwrap();
1415 cx.run_until_parked();
1416 thread.read_with(cx, |thread, _| {
1417 assert_eq!(
1418 thread.to_markdown(),
1419 indoc! {"
1420 ## User
1421
1422 Hello
1423 "}
1424 );
1425 });
1426
1427 fake_model.send_last_completion_stream_text_chunk("Hey!");
1428 cx.run_until_parked();
1429 thread.read_with(cx, |thread, _| {
1430 assert_eq!(
1431 thread.to_markdown(),
1432 indoc! {"
1433 ## User
1434
1435 Hello
1436
1437 ## Assistant
1438
1439 Hey!
1440 "}
1441 );
1442 });
1443
1444 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1445 fake_model
1446 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1447 let events = events.collect::<Vec<_>>().await;
1448 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1449 thread.read_with(cx, |thread, _| {
1450 assert_eq!(thread.to_markdown(), "");
1451 });
1452}
1453
// Rewinding to the first message clears the entire thread and resets token
// usage, and the thread accepts and tracks new messages afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the prompt id so we can rewind to this exact message later.
    let message_id = new_prompt_id();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update so latest_token_usage becomes Some.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens is the sum of input and output tokens reported above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Rewinding to the first message removes it and everything after it,
    // emptying the thread and clearing the usage stats.
    thread
        .update(cx, |thread, cx| thread.rewind(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hi"], cx))
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects the new turn, not the rewound one.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1567
// Rewinding to the second message restores the thread — content and token
// usage — to its state after the first exchange, leaving the first exchange
// intact.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: user message, assistant reply, and a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion for the post-first-exchange state; reused both
    // before the second send and after the rewind.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose prompt id we keep so we can rewind to it.
    let second_message_id = new_prompt_id();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Rewinding to the second message drops it and its response, restoring
    // the first exchange's markdown and usage.
    thread
        .update(cx, |thread, cx| thread.rewind(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1676
1677#[gpui::test]
1678#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1679async fn test_title_generation(cx: &mut TestAppContext) {
1680 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1681 let fake_model = model.as_fake();
1682
1683 let summary_model = Arc::new(FakeLanguageModel::default());
1684 thread.update(cx, |thread, cx| {
1685 thread.set_summarization_model(Some(summary_model.clone()), cx)
1686 });
1687
1688 let send = thread
1689 .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
1690 .unwrap();
1691 cx.run_until_parked();
1692
1693 fake_model.send_last_completion_stream_text_chunk("Hey!");
1694 fake_model.end_last_completion_stream();
1695 cx.run_until_parked();
1696 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1697
1698 // Ensure the summary model has been invoked to generate a title.
1699 summary_model.send_last_completion_stream_text_chunk("Hello ");
1700 summary_model.send_last_completion_stream_text_chunk("world\nG");
1701 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1702 summary_model.end_last_completion_stream();
1703 send.collect::<Vec<_>>().await;
1704 cx.run_until_parked();
1705 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1706
1707 // Send another message, ensuring no title is generated this time.
1708 let send = thread
1709 .update(cx, |thread, cx| {
1710 thread.send(new_prompt_id(), ["Hello again"], cx)
1711 })
1712 .unwrap();
1713 cx.run_until_parked();
1714 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1715 fake_model.end_last_completion_stream();
1716 cx.run_until_parked();
1717 assert_eq!(summary_model.pending_completions(), Vec::new());
1718 send.collect::<Vec<_>>().await;
1719 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1720}
1721
// When a turn has tool calls that are still pending (e.g. awaiting user
// permission), build_completion_request must omit the pending tool use and
// its (nonexistent) result, including only completed tool calls.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // tool_id_1 stays pending (permission never granted); tool_id_2 completes.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped here — presumably the system/context message;
    // the assertion covers the conversation proper. Note the assistant
    // message contains only the completed echo tool use, not the pending
    // permission tool use.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1796
// Exercises the NativeAgentConnection surface end to end: model listing and
// selection, thread creation, prompting, cancellation, and session cleanup
// once the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream back a reply.
    let request = acp_thread.update(cx, |thread, cx| {
        thread.send(new_prompt_id(), vec!["abc".into()], cx)
    });
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    // Cancellation resolves the in-flight request gracefully (no panic/hang).
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                acp::PromptRequest {
                    prompt_id: Some(new_prompt_id()),
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1931
// Verifies the exact ToolCall / ToolCallUpdate event sequence emitted while a
// tool's input streams in, the tool starts running, streams content, and
// completes.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Think"], cx))
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) Initial ToolCall event, emitted from the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) The tool transitions to InProgress.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) Completion with the final raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2031
2032#[gpui::test]
2033async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2034 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2035 let fake_model = model.as_fake();
2036
2037 let mut events = thread
2038 .update(cx, |thread, cx| {
2039 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2040 thread.send(new_prompt_id(), ["Hello!"], cx)
2041 })
2042 .unwrap();
2043 cx.run_until_parked();
2044
2045 fake_model.send_last_completion_stream_text_chunk("Hey!");
2046 fake_model.end_last_completion_stream();
2047
2048 let mut retry_events = Vec::new();
2049 while let Some(Ok(event)) = events.next().await {
2050 match event {
2051 ThreadEvent::Retry(retry_status) => {
2052 retry_events.push(retry_status);
2053 }
2054 ThreadEvent::Stop(..) => break,
2055 _ => {}
2056 }
2057 }
2058
2059 assert_eq!(retry_events.len(), 0);
2060 thread.read_with(cx, |thread, _cx| {
2061 assert_eq!(
2062 thread.to_markdown(),
2063 indoc! {"
2064 ## User
2065
2066 Hello!
2067
2068 ## Assistant
2069
2070 Hey!
2071 "}
2072 )
2073 });
2074}
2075
// A retryable provider error (ServerOverloaded) in Burn mode triggers a single
// retry after the advertised delay; the resumed completion's text is appended
// as a new assistant message, separated by a "[resume]" marker.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(new_prompt_id(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past retry_after so the scheduled retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2140
// When a completion fails mid-turn with a retryable error after the model has
// already issued a tool call, the tool call is still executed and its result
// is included in the retried request.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a complete tool use, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must contain both the tool use
    // and its result — i.e. the tool ran despite the stream error.
    // (messages[0] is skipped — presumably the system/context message.)
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and check the final assistant message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2216
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // Verifies that after exhausting the retry budget the thread stops
    // retrying and surfaces the provider error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries on provider errors.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(new_prompt_id(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt: the initial request plus each retry.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        // Advance past the retry delay so the next attempt (if any) starts.
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, separating retry notifications from errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Exactly one retry event per attempt, numbered starting at 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported once, as the original provider error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2271
2272/// Filters out the stop events for asserting against in tests
2273fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2274 result_events
2275 .into_iter()
2276 .filter_map(|event| match event.unwrap() {
2277 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2278 _ => None,
2279 })
2280 .collect()
2281}
2282
/// Handles returned by [`setup`] that tests use to drive and inspect a thread.
struct ThreadTest {
    /// The model the thread is wired to (fake or real per [`TestModel`]).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context the thread was constructed with.
    project_context: Entity<ProjectContext>,
    /// Store for MCP context servers; see [`setup_context_server`].
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing both the project tree and the settings file.
    fs: Arc<FakeFs>,
}
2290
/// Selects which language model a test runs against.
enum TestModel {
    /// A real model resolved from the registry and authenticated in `setup`.
    Sonnet4,
    /// The in-process `FakeLanguageModel`, fully scripted by the test.
    Fake,
}
2295
2296impl TestModel {
2297 fn id(&self) -> LanguageModelId {
2298 match self {
2299 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2300 TestModel::Fake => unreachable!(),
2301 }
2302 }
2303}
2304
/// Builds a fully-wired test fixture: a fake settings file enabling all test
/// tools, global client/settings initialization, a test project, and a
/// `Thread` backed by the requested [`TestModel`].
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that enables every test tool under a dedicated
    // profile, so the thread's active profile exposes them all.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize global state: settings, project, agent, HTTP client, and the
    // language-model registry.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client so tests against a live provider can reach it.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with the fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the requested model: the fake model is constructed directly,
    // while a real model is looked up in the registry and authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2404
// Runs once before any test in this binary (via `ctor`), enabling logging
// when the user requested it through the RUST_LOG environment variable.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2412
2413fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2414 let fs = fs.clone();
2415 cx.spawn({
2416 async move |cx| {
2417 let mut new_settings_content_rx = settings::watch_config_file(
2418 cx.background_executor(),
2419 fs,
2420 paths::settings_file().clone(),
2421 );
2422
2423 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2424 cx.update(|cx| {
2425 SettingsStore::update_global(cx, |settings, cx| {
2426 settings.set_user_settings(&new_settings_content, cx)
2427 })
2428 })
2429 .ok();
2430 }
2431 }
2432 })
2433 .detach();
2434}
2435
2436fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2437 completion
2438 .tools
2439 .iter()
2440 .map(|tool| tool.name.clone())
2441 .collect()
2442}
2443
/// Registers a fake MCP context server named `name` exposing `tools`, starts
/// it in `context_server_store`, and returns a channel over which the test
/// receives each incoming tool call along with a oneshot sender for supplying
/// its response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                // This command is never actually spawned; the fake transport
                // below intercepts all traffic.
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport implementing the minimal server surface: initialize
    // handshake, tool listing, and tool calls forwarded to the test.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and wait for it to respond
                // through the paired oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the initialize/list-tools handshake complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}