1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
76async fn test_thinking(cx: &mut TestAppContext) {
77 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
78 let fake_model = model.as_fake();
79
80 let events = thread
81 .update(cx, |thread, cx| {
82 thread.send(
83 UserMessageId::new(),
84 [indoc! {"
85 Testing:
86
87 Generate a thinking step where you just think the word 'Think',
88 and have your final answer be 'Hello'
89 "}],
90 cx,
91 )
92 })
93 .unwrap();
94 cx.run_until_parked();
95 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
96 text: "Think".to_string(),
97 signature: None,
98 });
99 fake_model.send_last_completion_stream_text_chunk("Hello");
100 fake_model
101 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
102 fake_model.end_last_completion_stream();
103
104 let events = events.collect().await;
105 thread.update(cx, |thread, _cx| {
106 assert_eq!(
107 thread.last_message().unwrap().to_markdown(),
108 indoc! {"
109 ## Assistant
110
111 <think>Think</think>
112 Hello
113 "}
114 )
115 });
116 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
117}
118
119#[gpui::test]
120async fn test_system_prompt(cx: &mut TestAppContext) {
121 let ThreadTest {
122 model,
123 thread,
124 project_context,
125 ..
126 } = setup(cx, TestModel::Fake).await;
127 let fake_model = model.as_fake();
128
129 project_context.update(cx, |project_context, _cx| {
130 project_context.shell = "test-shell".into()
131 });
132 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
133 thread
134 .update(cx, |thread, cx| {
135 thread.send(UserMessageId::new(), ["abc"], cx)
136 })
137 .unwrap();
138 cx.run_until_parked();
139 let mut pending_completions = fake_model.pending_completions();
140 assert_eq!(
141 pending_completions.len(),
142 1,
143 "unexpected pending completions: {:?}",
144 pending_completions
145 );
146
147 let pending_completion = pending_completions.pop().unwrap();
148 assert_eq!(pending_completion.messages[0].role, Role::System);
149
150 let system_message = &pending_completion.messages[0];
151 let system_prompt = system_message.content[0].to_str().unwrap();
152 assert!(
153 system_prompt.contains("test-shell"),
154 "unexpected system message: {:?}",
155 system_message
156 );
157 assert!(
158 system_prompt.contains("## Fixing Diagnostics"),
159 "unexpected system message: {:?}",
160 system_message
161 );
162}
163
/// Verifies the prompt-caching policy: only the most recent message in the
/// request history carries `cache: true`, across plain text turns and a
/// tool-use/tool-result exchange.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request should cache only the
    // trailing tool-result message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
297
298#[gpui::test]
299#[cfg_attr(not(feature = "e2e"), ignore)]
300async fn test_basic_tool_calls(cx: &mut TestAppContext) {
301 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
302
303 // Test a tool call that's likely to complete *before* streaming stops.
304 let events = thread
305 .update(cx, |thread, cx| {
306 thread.add_tool(EchoTool);
307 thread.send(
308 UserMessageId::new(),
309 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
310 cx,
311 )
312 })
313 .unwrap()
314 .collect()
315 .await;
316 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
317
318 // Test a tool calls that's likely to complete *after* streaming stops.
319 let events = thread
320 .update(cx, |thread, cx| {
321 thread.remove_tool(&EchoTool::name());
322 thread.add_tool(DelayTool);
323 thread.send(
324 UserMessageId::new(),
325 [
326 "Now call the delay tool with 200ms.",
327 "When the timer goes off, then you echo the output of the tool.",
328 ],
329 cx,
330 )
331 })
332 .unwrap()
333 .collect()
334 .await;
335 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
336 thread.update(cx, |thread, _cx| {
337 assert!(
338 thread
339 .last_message()
340 .unwrap()
341 .as_agent_message()
342 .unwrap()
343 .content
344 .iter()
345 .any(|content| {
346 if let AgentMessageContent::Text(text) = content {
347 text.contains("Ding")
348 } else {
349 false
350 }
351 }),
352 "{}",
353 thread.to_markdown()
354 );
355 });
356}
357
/// Verifies that tool-call input streams incrementally: while the call is
/// still `Pending`, the thread history should (at least once) contain a tool
/// use whose input is incomplete; once complete, all fields must be present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, a tool use with incomplete input
                        // (missing its last field "g") proves we observed a
                        // partially streamed state.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once no longer pending, the full input must have
                        // streamed in (first field "a" through last field "g").
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
409
410#[gpui::test]
411async fn test_tool_authorization(cx: &mut TestAppContext) {
412 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
413 let fake_model = model.as_fake();
414
415 let mut events = thread
416 .update(cx, |thread, cx| {
417 thread.add_tool(ToolRequiringPermission);
418 thread.send(UserMessageId::new(), ["abc"], cx)
419 })
420 .unwrap();
421 cx.run_until_parked();
422 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
423 LanguageModelToolUse {
424 id: "tool_id_1".into(),
425 name: ToolRequiringPermission::name().into(),
426 raw_input: "{}".into(),
427 input: json!({}),
428 is_input_complete: true,
429 },
430 ));
431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
432 LanguageModelToolUse {
433 id: "tool_id_2".into(),
434 name: ToolRequiringPermission::name().into(),
435 raw_input: "{}".into(),
436 input: json!({}),
437 is_input_complete: true,
438 },
439 ));
440 fake_model.end_last_completion_stream();
441 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
442 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
443
444 // Approve the first
445 tool_call_auth_1
446 .response
447 .send(tool_call_auth_1.options[1].id.clone())
448 .unwrap();
449 cx.run_until_parked();
450
451 // Reject the second
452 tool_call_auth_2
453 .response
454 .send(tool_call_auth_1.options[2].id.clone())
455 .unwrap();
456 cx.run_until_parked();
457
458 let completion = fake_model.pending_completions().pop().unwrap();
459 let message = completion.messages.last().unwrap();
460 assert_eq!(
461 message.content,
462 vec![
463 language_model::MessageContent::ToolResult(LanguageModelToolResult {
464 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
465 tool_name: ToolRequiringPermission::name().into(),
466 is_error: false,
467 content: "Allowed".into(),
468 output: Some("Allowed".into())
469 }),
470 language_model::MessageContent::ToolResult(LanguageModelToolResult {
471 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
472 tool_name: ToolRequiringPermission::name().into(),
473 is_error: true,
474 content: "Permission to run tool denied by user".into(),
475 output: Some("Permission to run tool denied by user".into())
476 })
477 ]
478 );
479
480 // Simulate yet another tool call.
481 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
482 LanguageModelToolUse {
483 id: "tool_id_3".into(),
484 name: ToolRequiringPermission::name().into(),
485 raw_input: "{}".into(),
486 input: json!({}),
487 is_input_complete: true,
488 },
489 ));
490 fake_model.end_last_completion_stream();
491
492 // Respond by always allowing tools.
493 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
494 tool_call_auth_3
495 .response
496 .send(tool_call_auth_3.options[0].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499 let completion = fake_model.pending_completions().pop().unwrap();
500 let message = completion.messages.last().unwrap();
501 assert_eq!(
502 message.content,
503 vec![language_model::MessageContent::ToolResult(
504 LanguageModelToolResult {
505 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
506 tool_name: ToolRequiringPermission::name().into(),
507 is_error: false,
508 content: "Allowed".into(),
509 output: Some("Allowed".into())
510 }
511 )]
512 );
513
514 // Simulate a final tool call, ensuring we don't trigger authorization.
515 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
516 LanguageModelToolUse {
517 id: "tool_id_4".into(),
518 name: ToolRequiringPermission::name().into(),
519 raw_input: "{}".into(),
520 input: json!({}),
521 is_input_complete: true,
522 },
523 ));
524 fake_model.end_last_completion_stream();
525 cx.run_until_parked();
526 let completion = fake_model.pending_completions().pop().unwrap();
527 let message = completion.messages.last().unwrap();
528 assert_eq!(
529 message.content,
530 vec![language_model::MessageContent::ToolResult(
531 LanguageModelToolResult {
532 tool_use_id: "tool_id_4".into(),
533 tool_name: ToolRequiringPermission::name().into(),
534 is_error: false,
535 content: "Allowed".into(),
536 output: Some("Allowed".into())
537 }
538 )]
539 );
540}
541
542#[gpui::test]
543async fn test_tool_hallucination(cx: &mut TestAppContext) {
544 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
545 let fake_model = model.as_fake();
546
547 let mut events = thread
548 .update(cx, |thread, cx| {
549 thread.send(UserMessageId::new(), ["abc"], cx)
550 })
551 .unwrap();
552 cx.run_until_parked();
553 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
554 LanguageModelToolUse {
555 id: "tool_id_1".into(),
556 name: "nonexistent_tool".into(),
557 raw_input: "{}".into(),
558 input: json!({}),
559 is_input_complete: true,
560 },
561 ));
562 fake_model.end_last_completion_stream();
563
564 let tool_call = expect_tool_call(&mut events).await;
565 assert_eq!(tool_call.title, "nonexistent_tool");
566 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
567 let update = expect_tool_call_update_fields(&mut events).await;
568 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
569}
570
/// After hitting the tool-use limit, `resume` should re-send the full history
/// plus a synthetic "Continue where you left off" user message, with only
/// that trailing message cached.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
679
/// After hitting the tool-use limit, sending a *new* user message (instead of
/// resuming) should preserve the prior tool use/result in the history and
/// cache only the new trailing message.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately hit the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
753
754async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
755 let event = events
756 .next()
757 .await
758 .expect("no tool call authorization event received")
759 .unwrap();
760 match event {
761 ThreadEvent::ToolCall(tool_call) => tool_call,
762 event => {
763 panic!("Unexpected event {event:?}");
764 }
765 }
766}
767
768async fn expect_tool_call_update_fields(
769 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
770) -> acp::ToolCallUpdate {
771 let event = events
772 .next()
773 .await
774 .expect("no tool call authorization event received")
775 .unwrap();
776 match event {
777 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
778 event => {
779 panic!("Unexpected event {event:?}");
780 }
781 }
782}
783
784async fn next_tool_call_authorization(
785 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
786) -> ToolCallAuthorization {
787 loop {
788 let event = events
789 .next()
790 .await
791 .expect("no tool call authorization event received")
792 .unwrap();
793 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
794 let permission_kinds = tool_call_authorization
795 .options
796 .iter()
797 .map(|o| o.kind)
798 .collect::<Vec<_>>();
799 assert_eq!(
800 permission_kinds,
801 vec![
802 acp::PermissionOptionKind::AllowAlways,
803 acp::PermissionOptionKind::AllowOnce,
804 acp::PermissionOptionKind::RejectOnce,
805 ]
806 );
807 return tool_call_authorization;
808 }
809 }
810}
811
812#[gpui::test]
813#[cfg_attr(not(feature = "e2e"), ignore)]
814async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
815 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
816
817 // Test concurrent tool calls with different delay times
818 let events = thread
819 .update(cx, |thread, cx| {
820 thread.add_tool(DelayTool);
821 thread.send(
822 UserMessageId::new(),
823 [
824 "Call the delay tool twice in the same message.",
825 "Once with 100ms. Once with 300ms.",
826 "When both timers are complete, describe the outputs.",
827 ],
828 cx,
829 )
830 })
831 .unwrap()
832 .collect()
833 .await;
834
835 let stop_reasons = stop_events(events);
836 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
837
838 thread.update(cx, |thread, _cx| {
839 let last_message = thread.last_message().unwrap();
840 let agent_message = last_message.as_agent_message().unwrap();
841 let text = agent_message
842 .content
843 .iter()
844 .filter_map(|content| {
845 if let AgentMessageContent::Text(text) = content {
846 Some(text.as_str())
847 } else {
848 None
849 }
850 })
851 .collect::<String>();
852
853 assert!(text.contains("Ding"));
854 });
855}
856
/// Verifies that switching agent profiles changes which tools are offered to
/// the model: profile "test-1" exposes echo + delay, "test-2" only infinite.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the profile decides which are
    // actually sent with each completion request.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
936
937#[gpui::test]
938async fn test_mcp_tools(cx: &mut TestAppContext) {
939 let ThreadTest {
940 model,
941 thread,
942 context_server_store,
943 fs,
944 ..
945 } = setup(cx, TestModel::Fake).await;
946 let fake_model = model.as_fake();
947
948 // Override profiles and wait for settings to be loaded.
949 fs.insert_file(
950 paths::settings_file(),
951 json!({
952 "agent": {
953 "profiles": {
954 "test": {
955 "name": "Test Profile",
956 "enable_all_context_servers": true,
957 "tools": {
958 EchoTool::name(): true,
959 }
960 },
961 }
962 }
963 })
964 .to_string()
965 .into_bytes(),
966 )
967 .await;
968 cx.run_until_parked();
969 thread.update(cx, |thread, _| {
970 thread.set_profile(AgentProfileId("test".into()))
971 });
972
973 let mut mcp_tool_calls = setup_context_server(
974 "test_server",
975 vec![context_server::types::Tool {
976 name: "echo".into(),
977 description: None,
978 input_schema: serde_json::to_value(
979 EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
980 )
981 .unwrap(),
982 output_schema: None,
983 annotations: None,
984 }],
985 &context_server_store,
986 cx,
987 );
988
989 let events = thread.update(cx, |thread, cx| {
990 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
991 });
992 cx.run_until_parked();
993
994 // Simulate the model calling the MCP tool.
995 let completion = fake_model.pending_completions().pop().unwrap();
996 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
997 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
998 LanguageModelToolUse {
999 id: "tool_1".into(),
1000 name: "echo".into(),
1001 raw_input: json!({"text": "test"}).to_string(),
1002 input: json!({"text": "test"}),
1003 is_input_complete: true,
1004 },
1005 ));
1006 fake_model.end_last_completion_stream();
1007 cx.run_until_parked();
1008
1009 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1010 assert_eq!(tool_call_params.name, "echo");
1011 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1012 tool_call_response
1013 .send(context_server::types::CallToolResponse {
1014 content: vec![context_server::types::ToolResponseContent::Text {
1015 text: "test".into(),
1016 }],
1017 is_error: None,
1018 meta: None,
1019 structured_content: None,
1020 })
1021 .unwrap();
1022 cx.run_until_parked();
1023
1024 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1025 fake_model.send_last_completion_stream_text_chunk("Done!");
1026 fake_model.end_last_completion_stream();
1027 events.collect::<Vec<_>>().await;
1028
1029 // Send again after adding the echo tool, ensuring the name collision is resolved.
1030 let events = thread.update(cx, |thread, cx| {
1031 thread.add_tool(EchoTool);
1032 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1033 });
1034 cx.run_until_parked();
1035 let completion = fake_model.pending_completions().pop().unwrap();
1036 assert_eq!(
1037 tool_names_for_completion(&completion),
1038 vec!["echo", "test_server_echo"]
1039 );
1040 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1041 LanguageModelToolUse {
1042 id: "tool_2".into(),
1043 name: "test_server_echo".into(),
1044 raw_input: json!({"text": "mcp"}).to_string(),
1045 input: json!({"text": "mcp"}),
1046 is_input_complete: true,
1047 },
1048 ));
1049 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1050 LanguageModelToolUse {
1051 id: "tool_3".into(),
1052 name: "echo".into(),
1053 raw_input: json!({"text": "native"}).to_string(),
1054 input: json!({"text": "native"}),
1055 is_input_complete: true,
1056 },
1057 ));
1058 fake_model.end_last_completion_stream();
1059 cx.run_until_parked();
1060
1061 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1062 assert_eq!(tool_call_params.name, "echo");
1063 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1064 tool_call_response
1065 .send(context_server::types::CallToolResponse {
1066 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1067 is_error: None,
1068 meta: None,
1069 structured_content: None,
1070 })
1071 .unwrap();
1072 cx.run_until_parked();
1073
1074 // Ensure the tool results were inserted with the correct names.
1075 let completion = fake_model.pending_completions().pop().unwrap();
1076 assert_eq!(
1077 completion.messages.last().unwrap().content,
1078 vec![
1079 MessageContent::ToolResult(LanguageModelToolResult {
1080 tool_use_id: "tool_3".into(),
1081 tool_name: "echo".into(),
1082 is_error: false,
1083 content: "native".into(),
1084 output: Some("native".into()),
1085 },),
1086 MessageContent::ToolResult(LanguageModelToolResult {
1087 tool_use_id: "tool_2".into(),
1088 tool_name: "test_server_echo".into(),
1089 is_error: false,
1090 content: "mcp".into(),
1091 output: Some("mcp".into()),
1092 },),
1093 ]
1094 );
1095 fake_model.end_last_completion_stream();
1096 events.collect::<Vec<_>>().await;
1097}
1098
// Verifies how tool names from MCP (context) servers are reconciled with
// native tools before being advertised to the model: colliding names get a
// server-derived prefix, and names at or over MAX_TOOL_NAME_LENGTH are
// shortened so every advertised name fits within the limit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all native tools and all context servers enabled,
    // so every tool registered below is eligible for the completion request.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools on the thread.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names.
    // Server "xxx": one name colliding with the native EchoTool plus a unique one.
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy": another "echo" collision, plus names just under the
    // length limit that also collide with server "zzz" below.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz": duplicates "yyy"'s near-limit names and adds one that is
    // a single character over the limit.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Expected resolution, per the assertion below: short colliding names are
    // prefixed with the full server id ("xxx_echo"/"yyy_echo"); near-limit
    // colliding names get a single-letter server prefix ("y_aaa…"/"z_aaa…")
    // so the result still fits; the over-limit "c…" name is truncated.
    // NOTE(review): only one of the two conflicting (MAX-1)-length "b…" names
    // appears — presumably a prefixed variant would exceed the limit; confirm
    // against the truncation logic.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1264
// e2e test (runs against a real Sonnet 4 model): cancelling a turn must close
// the event stream with StopReason::Cancelled — even while a tool is still
// running — and the thread must accept a new message afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically (the infinite
            // tool never finishes on its own).
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have been called and echo has completed.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1349
1350#[gpui::test]
1351#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1352async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1353 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1354 let fake_model = model.as_fake();
1355
1356 let events_1 = thread
1357 .update(cx, |thread, cx| {
1358 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1359 })
1360 .unwrap();
1361 cx.run_until_parked();
1362 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1363 cx.run_until_parked();
1364
1365 let events_2 = thread
1366 .update(cx, |thread, cx| {
1367 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1368 })
1369 .unwrap();
1370 cx.run_until_parked();
1371 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1372 fake_model
1373 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1374 fake_model.end_last_completion_stream();
1375
1376 let events_1 = events_1.collect::<Vec<_>>().await;
1377 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1378 let events_2 = events_2.collect::<Vec<_>>().await;
1379 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1380}
1381
1382#[gpui::test]
1383async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1384 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1385 let fake_model = model.as_fake();
1386
1387 let events_1 = thread
1388 .update(cx, |thread, cx| {
1389 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1390 })
1391 .unwrap();
1392 cx.run_until_parked();
1393 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1394 fake_model
1395 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1396 fake_model.end_last_completion_stream();
1397 let events_1 = events_1.collect::<Vec<_>>().await;
1398
1399 let events_2 = thread
1400 .update(cx, |thread, cx| {
1401 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1402 })
1403 .unwrap();
1404 cx.run_until_parked();
1405 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1406 fake_model
1407 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1408 fake_model.end_last_completion_stream();
1409 let events_2 = events_2.collect::<Vec<_>>().await;
1410
1411 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1412 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1413}
1414
// When the model stops with StopReason::Refusal, the thread must surface the
// refusal to the caller and roll back everything after the last user message
// (here that empties the thread entirely).
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant response so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The user message and partial response were both removed.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1463
// Truncating at the first (and only) user message must clear the thread and
// reset token usage; the thread must then accept and track a fresh exchange.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported before the model responds.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update so there is state to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the usage update.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message removes the whole exchange and the
    // recorded usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message appears synchronously, before the model replies.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the new exchange, not the truncated one.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1579
// Truncating at the second user message must restore the thread — content and
// token usage — to exactly its state after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message, response, and a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot of the expected post-first-exchange state; reused after
    // truncation to verify a full rollback.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange: keep its id so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges are present and usage reflects the latest update.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the second message must restore the first-exchange state,
    // including the earlier token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1688
// The summarization model should be asked to generate a title exactly once —
// after the first exchange — and never again for subsequent messages.
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model dedicated to summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title is kept.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Only the first line of its output is used ("Goodnight Moon" is dropped).
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1735
// When a completion request is rebuilt mid-turn, tool calls that are still
// pending (here: one awaiting user permission) must be excluded, while
// completed tool calls appear as a ToolUse plus its matching ToolResult.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will remain pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs to completion.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped here; presumably the system message — verify.
    // Note the permission tool use is absent: only the completed echo call is
    // represented, as an assistant ToolUse followed by a user ToolResult.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1810
// End-to-end exercise of the NativeAgentConnection ACP surface: model
// listing/selection, creating a thread, prompting, cancellation, and session
// teardown when the AcpThread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: the test registry exposes the fake provider/model.
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected id back to the concrete fake model instance.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a response from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1943
// Verifies the exact sequence of ACP events emitted for a tool call whose
// input streams in incrementally: initial ToolCall (Pending), then field
// updates for the completed input, InProgress status, streamed content, and
// finally Completed with the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The partial tool use surfaces as a Pending ToolCall with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2. Completed input arrives as a field update carrying the full raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3. The tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5. The tool finishes: Completed status plus the raw output value.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2045
2046#[gpui::test]
2047async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2048 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2049 let fake_model = model.as_fake();
2050
2051 let mut events = thread
2052 .update(cx, |thread, cx| {
2053 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2054 thread.send(UserMessageId::new(), ["Hello!"], cx)
2055 })
2056 .unwrap();
2057 cx.run_until_parked();
2058
2059 fake_model.send_last_completion_stream_text_chunk("Hey!");
2060 fake_model.end_last_completion_stream();
2061
2062 let mut retry_events = Vec::new();
2063 while let Some(Ok(event)) = events.next().await {
2064 match event {
2065 ThreadEvent::Retry(retry_status) => {
2066 retry_events.push(retry_status);
2067 }
2068 ThreadEvent::Stop(..) => break,
2069 _ => {}
2070 }
2071 }
2072
2073 assert_eq!(retry_events.len(), 0);
2074 thread.read_with(cx, |thread, _cx| {
2075 assert_eq!(
2076 thread.to_markdown(),
2077 indoc! {"
2078 ## User
2079
2080 Hello!
2081
2082 ## Assistant
2083
2084 Hey!
2085 "}
2086 )
2087 });
2088}
2089
// A ServerOverloaded stream error in Burn mode must trigger exactly one retry
// (after the provider's retry_after delay) that resumes the interrupted
// assistant turn.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry request is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the turn.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events up to the Stop, collecting Retry notifications.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the partial response, marks the resume point, and
    // appends the retried response as a new assistant message.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2154
// If the stream errors after a tool use was emitted, the tool still runs to
// completion during the retry wait, and the retried request must contain the
// tool use plus its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a tool use, then fail the stream with a retryable overload error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past retry_after; the echo tool has finished by the time the
    // retried request is built, so its result is included (and cached).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // messages[0] is skipped here; presumably the system message — verify.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and confirm the final assistant message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2230
2231#[gpui::test]
2232async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2233 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2234 let fake_model = model.as_fake();
2235
2236 let mut events = thread
2237 .update(cx, |thread, cx| {
2238 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2239 thread.send(UserMessageId::new(), ["Hello!"], cx)
2240 })
2241 .unwrap();
2242 cx.run_until_parked();
2243
2244 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2245 fake_model.send_last_completion_stream_error(
2246 LanguageModelCompletionError::ServerOverloaded {
2247 provider: LanguageModelProviderName::new("Anthropic"),
2248 retry_after: Some(Duration::from_secs(3)),
2249 },
2250 );
2251 fake_model.end_last_completion_stream();
2252 cx.executor().advance_clock(Duration::from_secs(3));
2253 cx.run_until_parked();
2254 }
2255
2256 let mut errors = Vec::new();
2257 let mut retry_events = Vec::new();
2258 while let Some(event) = events.next().await {
2259 match event {
2260 Ok(ThreadEvent::Retry(retry_status)) => {
2261 retry_events.push(retry_status);
2262 }
2263 Ok(ThreadEvent::Stop(..)) => break,
2264 Err(error) => errors.push(error),
2265 _ => {}
2266 }
2267 }
2268
2269 assert_eq!(
2270 retry_events.len(),
2271 crate::thread::MAX_RETRY_ATTEMPTS as usize
2272 );
2273 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2274 assert_eq!(retry_events[i].attempt, i + 1);
2275 }
2276 assert_eq!(errors.len(), 1);
2277 let error = errors[0]
2278 .downcast_ref::<LanguageModelCompletionError>()
2279 .unwrap();
2280 assert!(matches!(
2281 error,
2282 LanguageModelCompletionError::ServerOverloaded { .. }
2283 ));
2284}
2285
2286/// Filters out the stop events for asserting against in tests
2287fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2288 result_events
2289 .into_iter()
2290 .filter_map(|event| match event.unwrap() {
2291 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2292 _ => None,
2293 })
2294 .collect()
2295}
2296
/// Fixture returned by [`setup`]: the thread under test plus the handles a
/// test needs to drive and inspect it.
struct ThreadTest {
    /// The active language model (a `FakeLanguageModel` for most tests).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    /// Store used by tests that spin up fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
2304
/// Selects which language model [`setup`] wires into the thread.
enum TestModel {
    /// A real Anthropic model resolved from the registry and authenticated.
    Sonnet4,
    /// The in-process `FakeLanguageModel`.
    Fake,
}
2309
2310impl TestModel {
2311 fn id(&self) -> LanguageModelId {
2312 match self {
2313 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2314 TestModel::Fake => unreachable!(),
2315 }
2316 }
2317}
2318
/// Builds the test fixture: a fake filesystem seeded with agent settings, a
/// test project, the requested language model, and a [`Thread`] wired to all
/// of them. Real-model variants block on provider authentication.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed the settings file with a profile that enables every test tool.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize global state: settings, HTTP client, cloud client, and the
    // language-model registry. Order matters; settings come first.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits tests make to the
        // settings file on the fake fs.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Fake models are constructed directly; real models are looked up in the
    // registry and their provider authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2418
/// Runs once before tests (via `ctor`) and enables `env_logger` output when
/// the caller opted in through the RUST_LOG environment variable.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2426
2427fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2428 let fs = fs.clone();
2429 cx.spawn({
2430 async move |cx| {
2431 let mut new_settings_content_rx = settings::watch_config_file(
2432 cx.background_executor(),
2433 fs,
2434 paths::settings_file().clone(),
2435 );
2436
2437 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2438 cx.update(|cx| {
2439 SettingsStore::update_global(cx, |settings, cx| {
2440 settings.set_user_settings(&new_settings_content, cx)
2441 })
2442 })
2443 .ok();
2444 }
2445 }
2446 })
2447 .detach();
2448}
2449
2450fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2451 completion
2452 .tools
2453 .iter()
2454 .map(|tool| tool.name.clone())
2455 .collect()
2456}
2457
/// Registers and starts a fake MCP context server named `name` that serves
/// `tools`. Returns a channel that yields each incoming `CallTool` request
/// together with a oneshot sender the test uses to supply the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it.
    // The command is never executed; the fake transport below intercepts it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport: answer the MCP handshake and tool listing inline, and
    // forward tool calls to the test through the returned channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test responds through the oneshot.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its startup handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}