1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
76async fn test_thinking(cx: &mut TestAppContext) {
77 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
78 let fake_model = model.as_fake();
79
80 let events = thread
81 .update(cx, |thread, cx| {
82 thread.send(
83 UserMessageId::new(),
84 [indoc! {"
85 Testing:
86
87 Generate a thinking step where you just think the word 'Think',
88 and have your final answer be 'Hello'
89 "}],
90 cx,
91 )
92 })
93 .unwrap();
94 cx.run_until_parked();
95 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
96 text: "Think".to_string(),
97 signature: None,
98 });
99 fake_model.send_last_completion_stream_text_chunk("Hello");
100 fake_model
101 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
102 fake_model.end_last_completion_stream();
103
104 let events = events.collect().await;
105 thread.update(cx, |thread, _cx| {
106 assert_eq!(
107 thread.last_message().unwrap().to_markdown(),
108 indoc! {"
109 ## Assistant
110
111 <think>Think</think>
112 Hello
113 "}
114 )
115 });
116 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
117}
118
119#[gpui::test]
120async fn test_system_prompt(cx: &mut TestAppContext) {
121 let ThreadTest {
122 model,
123 thread,
124 project_context,
125 ..
126 } = setup(cx, TestModel::Fake).await;
127 let fake_model = model.as_fake();
128
129 project_context.update(cx, |project_context, _cx| {
130 project_context.shell = "test-shell".into()
131 });
132 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
133 thread
134 .update(cx, |thread, cx| {
135 thread.send(UserMessageId::new(), ["abc"], cx)
136 })
137 .unwrap();
138 cx.run_until_parked();
139 let mut pending_completions = fake_model.pending_completions();
140 assert_eq!(
141 pending_completions.len(),
142 1,
143 "unexpected pending completions: {:?}",
144 pending_completions
145 );
146
147 let pending_completion = pending_completions.pop().unwrap();
148 assert_eq!(pending_completion.messages[0].role, Role::System);
149
150 let system_message = &pending_completion.messages[0];
151 let system_prompt = system_message.content[0].to_str().unwrap();
152 assert!(
153 system_prompt.contains("test-shell"),
154 "unexpected system message: {:?}",
155 system_message
156 );
157 assert!(
158 system_prompt.contains("## Fixing Diagnostics"),
159 "unexpected system message: {:?}",
160 system_message
161 );
162}
163
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching anchor policy: only the final message of
    // each request carries `cache: true`; everything earlier is uncached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so assertions compare from index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Running the echo tool triggers a follow-up request whose final message
    // is the tool result; that message becomes the new cache anchor.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
297
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the "e2e" feature).
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should echo the delay tool's output (presumably "Ding" —
    // see DelayTool in test_tools) in its final message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
357
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test: verifies that tool-call input streams incrementally,
    // i.e. the thread history exposes a tool use before its input is complete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to call a tool whose input is large enough to stream in
    // multiple chunks.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may be incomplete: the later
                        // keys (e.g. "g") haven't streamed in yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
409
410#[gpui::test]
411async fn test_tool_authorization(cx: &mut TestAppContext) {
412 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
413 let fake_model = model.as_fake();
414
415 let mut events = thread
416 .update(cx, |thread, cx| {
417 thread.add_tool(ToolRequiringPermission);
418 thread.send(UserMessageId::new(), ["abc"], cx)
419 })
420 .unwrap();
421 cx.run_until_parked();
422 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
423 LanguageModelToolUse {
424 id: "tool_id_1".into(),
425 name: ToolRequiringPermission::name().into(),
426 raw_input: "{}".into(),
427 input: json!({}),
428 is_input_complete: true,
429 },
430 ));
431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
432 LanguageModelToolUse {
433 id: "tool_id_2".into(),
434 name: ToolRequiringPermission::name().into(),
435 raw_input: "{}".into(),
436 input: json!({}),
437 is_input_complete: true,
438 },
439 ));
440 fake_model.end_last_completion_stream();
441 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
442 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
443
444 // Approve the first
445 tool_call_auth_1
446 .response
447 .send(tool_call_auth_1.options[1].id.clone())
448 .unwrap();
449 cx.run_until_parked();
450
451 // Reject the second
452 tool_call_auth_2
453 .response
454 .send(tool_call_auth_1.options[2].id.clone())
455 .unwrap();
456 cx.run_until_parked();
457
458 let completion = fake_model.pending_completions().pop().unwrap();
459 let message = completion.messages.last().unwrap();
460 assert_eq!(
461 message.content,
462 vec![
463 language_model::MessageContent::ToolResult(LanguageModelToolResult {
464 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
465 tool_name: ToolRequiringPermission::name().into(),
466 is_error: false,
467 content: "Allowed".into(),
468 output: Some("Allowed".into())
469 }),
470 language_model::MessageContent::ToolResult(LanguageModelToolResult {
471 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
472 tool_name: ToolRequiringPermission::name().into(),
473 is_error: true,
474 content: "Permission to run tool denied by user".into(),
475 output: Some("Permission to run tool denied by user".into())
476 })
477 ]
478 );
479
480 // Simulate yet another tool call.
481 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
482 LanguageModelToolUse {
483 id: "tool_id_3".into(),
484 name: ToolRequiringPermission::name().into(),
485 raw_input: "{}".into(),
486 input: json!({}),
487 is_input_complete: true,
488 },
489 ));
490 fake_model.end_last_completion_stream();
491
492 // Respond by always allowing tools.
493 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
494 tool_call_auth_3
495 .response
496 .send(tool_call_auth_3.options[0].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499 let completion = fake_model.pending_completions().pop().unwrap();
500 let message = completion.messages.last().unwrap();
501 assert_eq!(
502 message.content,
503 vec![language_model::MessageContent::ToolResult(
504 LanguageModelToolResult {
505 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
506 tool_name: ToolRequiringPermission::name().into(),
507 is_error: false,
508 content: "Allowed".into(),
509 output: Some("Allowed".into())
510 }
511 )]
512 );
513
514 // Simulate a final tool call, ensuring we don't trigger authorization.
515 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
516 LanguageModelToolUse {
517 id: "tool_id_4".into(),
518 name: ToolRequiringPermission::name().into(),
519 raw_input: "{}".into(),
520 input: json!({}),
521 is_input_complete: true,
522 },
523 ));
524 fake_model.end_last_completion_stream();
525 cx.run_until_parked();
526 let completion = fake_model.pending_completions().pop().unwrap();
527 let message = completion.messages.last().unwrap();
528 assert_eq!(
529 message.content,
530 vec![language_model::MessageContent::ToolResult(
531 LanguageModelToolResult {
532 tool_use_id: "tool_id_4".into(),
533 tool_name: ToolRequiringPermission::name().into(),
534 is_error: false,
535 content: "Allowed".into(),
536 output: Some("Allowed".into())
537 }
538 )]
539 );
540}
541
542#[gpui::test]
543async fn test_tool_hallucination(cx: &mut TestAppContext) {
544 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
545 let fake_model = model.as_fake();
546
547 let mut events = thread
548 .update(cx, |thread, cx| {
549 thread.send(UserMessageId::new(), ["abc"], cx)
550 })
551 .unwrap();
552 cx.run_until_parked();
553 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
554 LanguageModelToolUse {
555 id: "tool_id_1".into(),
556 name: "nonexistent_tool".into(),
557 raw_input: "{}".into(),
558 input: json!({}),
559 is_input_complete: true,
560 },
561 ));
562 fake_model.end_last_completion_stream();
563
564 let tool_call = expect_tool_call(&mut events).await;
565 assert_eq!(tool_call.title, "nonexistent_tool");
566 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
567 let update = expect_tool_call_update_fields(&mut events).await;
568 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
569}
570
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // Simulates hitting the server-side tool-use limit mid-turn, then resuming:
    // resume() appends a synthetic "Continue where you left off" user message,
    // which becomes the new cache anchor.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // After the tool runs, the follow-up request ends with the tool result,
    // which carries the cache anchor (messages[0] is the system prompt).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the stream's final event: a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resume the thread; the request gains a trailing synthetic user message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
679
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the tool-use limit is reached, sending a *new* user message (rather
    // than resuming) should still include the prior tool use/result in the next
    // request, with the fresh user message as the cache anchor.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use immediately followed by the limit-reached status.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the stream's final event: a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; compare from index 1 onwards.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
753
754async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
755 let event = events
756 .next()
757 .await
758 .expect("no tool call authorization event received")
759 .unwrap();
760 match event {
761 ThreadEvent::ToolCall(tool_call) => tool_call,
762 event => {
763 panic!("Unexpected event {event:?}");
764 }
765 }
766}
767
768async fn expect_tool_call_update_fields(
769 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
770) -> acp::ToolCallUpdate {
771 let event = events
772 .next()
773 .await
774 .expect("no tool call authorization event received")
775 .unwrap();
776 match event {
777 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
778 event => {
779 panic!("Unexpected event {event:?}");
780 }
781 }
782}
783
784async fn next_tool_call_authorization(
785 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
786) -> ToolCallAuthorization {
787 loop {
788 let event = events
789 .next()
790 .await
791 .expect("no tool call authorization event received")
792 .unwrap();
793 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
794 let permission_kinds = tool_call_authorization
795 .options
796 .iter()
797 .map(|o| o.kind)
798 .collect::<Vec<_>>();
799 assert_eq!(
800 permission_kinds,
801 vec![
802 acp::PermissionOptionKind::AllowAlways,
803 acp::PermissionOptionKind::AllowOnce,
804 acp::PermissionOptionKind::RejectOnce,
805 ]
806 );
807 return tool_call_authorization;
808 }
809 }
810}
811
812#[gpui::test]
813#[cfg_attr(not(feature = "e2e"), ignore)]
814async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
815 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
816
817 // Test concurrent tool calls with different delay times
818 let events = thread
819 .update(cx, |thread, cx| {
820 thread.add_tool(DelayTool);
821 thread.send(
822 UserMessageId::new(),
823 [
824 "Call the delay tool twice in the same message.",
825 "Once with 100ms. Once with 300ms.",
826 "When both timers are complete, describe the outputs.",
827 ],
828 cx,
829 )
830 })
831 .unwrap()
832 .collect()
833 .await;
834
835 let stop_reasons = stop_events(events);
836 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
837
838 thread.update(cx, |thread, _cx| {
839 let last_message = thread.last_message().unwrap();
840 let agent_message = last_message.as_agent_message().unwrap();
841 let text = agent_message
842 .content
843 .iter()
844 .filter_map(|content| {
845 if let AgentMessageContent::Text(text) = content {
846 Some(text.as_str())
847 } else {
848 None
849 }
850 })
851 .collect::<String>();
852
853 assert!(text.contains("Ding"));
854 });
855}
856
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles (configured in the settings file) control
    // which registered tools are exposed to the model per request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
936
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that MCP (context server) tools are exposed to the model, that
    // their calls round-trip through the server, and that a name collision
    // with a native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing an "echo" tool (same name as the native
    // EchoTool, which is not yet added to the thread).
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call reaches the MCP server with its arguments intact.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the request captured before the
    // tool result — presumably this was meant to re-pop the new pending
    // completion; confirm whether this assertion checks the intended request.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    // The MCP tool is renamed "test_server_echo"; the native one keeps "echo".
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the call under its original (unprefixed) name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1099
// Verifies how native and MCP (context-server) tool names are disambiguated and
// truncated before being offered to the model: names that conflict across
// sources get a server-derived prefix, and names at or over
// MAX_TOOL_NAME_LENGTH are shortened to fit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile above and register all the native tools.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Name just under the limit; also defined by server "zzz" below, so
            // the full server-name prefix cannot fit within MAX_TOOL_NAME_LENGTH.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Name over the limit — expected to be truncated to fit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Inspect the tool list actually offered to the model. Conflicting "echo"
    // MCP tools are prefixed with their server id ("xxx_echo"/"yyy_echo");
    // conflicting long names get a single-letter server prefix ("y_…"/"z_…")
    // so they still fit — NOTE(review): exact prefixing scheme inferred from
    // these expectations; confirm against the production truncation code.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1265
// End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
// cancelling a turn must close the event stream even while a tool is still
// running, and the thread must accept new messages afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the prompted order; remember the echo
            // call's id so we can watch for its completion below.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only a Completed status update for the echo call counts; other
            // updates (and the still-running infinite tool) are ignored.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1350
1351#[gpui::test]
1352#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1353async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1354 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1355 let fake_model = model.as_fake();
1356
1357 let events_1 = thread
1358 .update(cx, |thread, cx| {
1359 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1360 })
1361 .unwrap();
1362 cx.run_until_parked();
1363 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1364 cx.run_until_parked();
1365
1366 let events_2 = thread
1367 .update(cx, |thread, cx| {
1368 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1369 })
1370 .unwrap();
1371 cx.run_until_parked();
1372 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1373 fake_model
1374 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1375 fake_model.end_last_completion_stream();
1376
1377 let events_1 = events_1.collect::<Vec<_>>().await;
1378 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1379 let events_2 = events_2.collect::<Vec<_>>().await;
1380 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1381}
1382
1383#[gpui::test]
1384async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1385 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1386 let fake_model = model.as_fake();
1387
1388 let events_1 = thread
1389 .update(cx, |thread, cx| {
1390 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1391 })
1392 .unwrap();
1393 cx.run_until_parked();
1394 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1395 fake_model
1396 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1397 fake_model.end_last_completion_stream();
1398 let events_1 = events_1.collect::<Vec<_>>().await;
1399
1400 let events_2 = thread
1401 .update(cx, |thread, cx| {
1402 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1403 })
1404 .unwrap();
1405 cx.run_until_parked();
1406 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1407 fake_model
1408 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1409 fake_model.end_last_completion_stream();
1410 let events_2 = events_2.collect::<Vec<_>>().await;
1411
1412 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1413 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1414}
1415
1416#[gpui::test]
1417async fn test_refusal(cx: &mut TestAppContext) {
1418 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1419 let fake_model = model.as_fake();
1420
1421 let events = thread
1422 .update(cx, |thread, cx| {
1423 thread.send(UserMessageId::new(), ["Hello"], cx)
1424 })
1425 .unwrap();
1426 cx.run_until_parked();
1427 thread.read_with(cx, |thread, _| {
1428 assert_eq!(
1429 thread.to_markdown(),
1430 indoc! {"
1431 ## User
1432
1433 Hello
1434 "}
1435 );
1436 });
1437
1438 fake_model.send_last_completion_stream_text_chunk("Hey!");
1439 cx.run_until_parked();
1440 thread.read_with(cx, |thread, _| {
1441 assert_eq!(
1442 thread.to_markdown(),
1443 indoc! {"
1444 ## User
1445
1446 Hello
1447
1448 ## Assistant
1449
1450 Hey!
1451 "}
1452 );
1453 });
1454
1455 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1456 fake_model
1457 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1458 let events = events.collect::<Vec<_>>().await;
1459 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1460 thread.read_with(cx, |thread, _| {
1461 assert_eq!(thread.to_markdown(), "");
1462 });
1463}
1464
// Truncating at the very first user message must empty the thread, reset the
// token-usage accounting, and still allow sending fresh messages afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate back to this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has reported usage yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply along with a usage update so latest_token_usage has data.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects input + output tokens from the update above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: thread content and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage now reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1580
// Truncating at the second user message must restore the thread — content and
// token usage — to exactly the state it had after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: one user message and one complete assistant reply with
    // a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable snapshot check: asserted both before the second exchange and
    // again after truncation, to prove truncation restored this exact state.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate back to it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        // Usage reflects the latest (second) completion.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and confirm we're back to the first
    // exchange's state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1689
// A title is generated by the summarization model after the first exchange
// (using only the first line of its output); subsequent exchanges must not
// trigger another title generation.
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Chunks are split mid-word on purpose; only the first line ("Hello world")
    // should become the title — the text after the newline is discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1736
// While one tool call is still awaiting user permission, building a completion
// request must omit that pending tool use/result entirely, while completed
// tool calls are included as a ToolUse plus a matching ToolResult.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This one stays pending: the permission prompt is never answered.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This one runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped — presumably the system/context message; verify
    // against build_completion_request if this assumption changes.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1811
// Integration test for the ACP agent connection: model listing/selection,
// thread creation, prompting, cancellation, and session teardown when the ACP
// thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // All HTTP traffic gets a 404 — nothing here should hit the network.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream back a reply through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with a session-lookup error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1944
// Verifies the exact sequence of ACP tool-call events emitted for a single
// tool use: Pending call → raw-input update (after input streaming finishes)
// → InProgress → streamed content → Completed with the tool's raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) The initial tool call arrives in Pending state with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Completed input streaming updates the raw input (title/kind resent).
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) Execution begins.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) The tool finishes, carrying its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2046
2047#[gpui::test]
2048async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2049 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2050 let fake_model = model.as_fake();
2051
2052 let mut events = thread
2053 .update(cx, |thread, cx| {
2054 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2055 thread.send(UserMessageId::new(), ["Hello!"], cx)
2056 })
2057 .unwrap();
2058 cx.run_until_parked();
2059
2060 fake_model.send_last_completion_stream_text_chunk("Hey!");
2061 fake_model.end_last_completion_stream();
2062
2063 let mut retry_events = Vec::new();
2064 while let Some(Ok(event)) = events.next().await {
2065 match event {
2066 ThreadEvent::Retry(retry_status) => {
2067 retry_events.push(retry_status);
2068 }
2069 ThreadEvent::Stop(..) => break,
2070 _ => {}
2071 }
2072 }
2073
2074 assert_eq!(retry_events.len(), 0);
2075 thread.read_with(cx, |thread, _cx| {
2076 assert_eq!(
2077 thread.to_markdown(),
2078 indoc! {"
2079 ## User
2080
2081 Hello!
2082
2083 ## Assistant
2084
2085 Hey!
2086 "}
2087 )
2088 });
2089}
2090
// A ServerOverloaded error mid-stream must trigger exactly one Retry event,
// wait out the provider's retry_after delay, resume the completion, and keep
// both the pre-error and post-retry assistant text in the thread.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode — NOTE(review): presumably required for automatic
            // retries; confirm against the completion-mode handling.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past retry_after so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Collect Retry events until the turn stops.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // Both halves of the reply survive, separated by a [resume] marker.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2155
// If the stream errors after a tool use was received, the tool still runs to
// completion, and the retried request must include both the ToolUse and its
// ToolResult so the model can pick up where it left off.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Deliver a complete tool use, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the next request must carry the tool use AND its
    // completed result (messages[0] — presumably the system message — is
    // skipped from the comparison).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2231
#[gpui::test]
// Verifies that after MAX_RETRY_ATTEMPTS consecutive retryable failures the
// thread emits one Retry event per attempt and then surfaces the underlying
// error instead of retrying forever.
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (MAX_RETRY_ATTEMPTS + 1 total
    // stream errors), advancing the clock past retry_after each time so the
    // next retry is scheduled and run.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, collecting Retry notifications and errors until
    // the thread stops.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS, and
    // exactly one terminal error that preserves the original failure kind.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2286
2287/// Filters out the stop events for asserting against in tests
2288fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2289 result_events
2290 .into_iter()
2291 .filter_map(|event| match event.unwrap() {
2292 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2293 _ => None,
2294 })
2295 .collect()
2296}
2297
/// Everything a test needs from `setup`: the model under test, the thread
/// driving it, and the supporting project/file-system state.
struct ThreadTest {
    // The language model the thread talks to (fake or real, per `TestModel`).
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    // Kept so tests can mutate files (e.g. rewrite the settings file).
    fs: Arc<FakeFs>,
}
2305
/// Which model `setup` should wire up: a real Claude Sonnet 4 (requires
/// authentication against the production client) or the fake in-process model.
enum TestModel {
    Sonnet4,
    Fake,
}
2310
2311impl TestModel {
2312 fn id(&self) -> LanguageModelId {
2313 match self {
2314 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2315 TestModel::Fake => unreachable!(),
2316 }
2317 }
2318}
2319
/// Builds the full test environment for a `Thread`: a fake file system with a
/// settings file enabling all test tools, initialized settings/model globals,
/// a test project, and the requested language model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model authentication can block; allow it.
    cx.executor().allow_parking();

    // Write a settings file enabling every test tool under a dedicated
    // profile, before settings are initialized below.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize the global state the thread depends on: settings, project,
    // agent settings, HTTP client, and the language-model registry.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep global settings in sync with later edits to the settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly; real ones are
    // looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2419
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Logging is opt-in: only install env_logger when RUST_LOG is set.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2427
2428fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2429 let fs = fs.clone();
2430 cx.spawn({
2431 async move |cx| {
2432 let mut new_settings_content_rx = settings::watch_config_file(
2433 cx.background_executor(),
2434 fs,
2435 paths::settings_file().clone(),
2436 );
2437
2438 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2439 cx.update(|cx| {
2440 SettingsStore::update_global(cx, |settings, cx| {
2441 settings.set_user_settings(&new_settings_content, cx)
2442 })
2443 })
2444 .ok();
2445 }
2446 }
2447 })
2448 .detach();
2449}
2450
2451fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2452 completion
2453 .tools
2454 .iter()
2455 .map(|tool| tool.name.clone())
2456 .collect()
2457}
2458
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`. Returns a receiver yielding each incoming tool call together with
/// a oneshot sender the test uses to reply to it.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings first so the store treats it as
    // an enabled custom server when it starts below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // Never actually executed — the fake transport intercepts
                    // all traffic.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport implementing the minimal MCP handshake: Initialize,
    // ListTools (returning `tools`), and CallTool (forwarded to the test).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test sends a response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake (Initialize + ListTools) complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}