1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the first request message is a system prompt built from
    // the project context (here, the configured shell), and that registering
    // a tool pulls in the diagnostics guidance section.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a recognizable shell name so we can detect it in the prompt text.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt is always the first message of the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // With at least one tool registered, the prompt includes the
    // diagnostics-fixing section (contrast with the no-tools test below).
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the cache-marking policy for requests: exactly one message —
    // the most recent one — carries `cache: true`, so the provider can cache
    // the longest shared prefix of the conversation.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; comparisons skip it via [1..].
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request should mark only the
    // trailing tool-result message as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the `e2e`
    // feature): exercises tool calls that finish both before and after the
    // model's text stream ends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should have been echoed back in
        // the assistant's final message.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
393
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model): while tool-call input is still
    // streaming, the thread should expose a partially-populated tool use.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may be incomplete; a missing
                        // later field ("g") indicates mid-stream state.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, all input fields must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
445
446#[gpui::test]
447async fn test_tool_authorization(cx: &mut TestAppContext) {
448 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 let mut events = thread
452 .update(cx, |thread, cx| {
453 thread.add_tool(ToolRequiringPermission);
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
459 LanguageModelToolUse {
460 id: "tool_id_1".into(),
461 name: ToolRequiringPermission::name().into(),
462 raw_input: "{}".into(),
463 input: json!({}),
464 is_input_complete: true,
465 thought_signature: None,
466 },
467 ));
468 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
469 LanguageModelToolUse {
470 id: "tool_id_2".into(),
471 name: ToolRequiringPermission::name().into(),
472 raw_input: "{}".into(),
473 input: json!({}),
474 is_input_complete: true,
475 thought_signature: None,
476 },
477 ));
478 fake_model.end_last_completion_stream();
479 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
480 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
481
482 // Approve the first
483 tool_call_auth_1
484 .response
485 .send(tool_call_auth_1.options[1].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488
489 // Reject the second
490 tool_call_auth_2
491 .response
492 .send(tool_call_auth_1.options[2].id.clone())
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![
501 language_model::MessageContent::ToolResult(LanguageModelToolResult {
502 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }),
508 language_model::MessageContent::ToolResult(LanguageModelToolResult {
509 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
510 tool_name: ToolRequiringPermission::name().into(),
511 is_error: true,
512 content: "Permission to run tool denied by user".into(),
513 output: Some("Permission to run tool denied by user".into())
514 })
515 ]
516 );
517
518 // Simulate yet another tool call.
519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
520 LanguageModelToolUse {
521 id: "tool_id_3".into(),
522 name: ToolRequiringPermission::name().into(),
523 raw_input: "{}".into(),
524 input: json!({}),
525 is_input_complete: true,
526 thought_signature: None,
527 },
528 ));
529 fake_model.end_last_completion_stream();
530
531 // Respond by always allowing tools.
532 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
533 tool_call_auth_3
534 .response
535 .send(tool_call_auth_3.options[0].id.clone())
536 .unwrap();
537 cx.run_until_parked();
538 let completion = fake_model.pending_completions().pop().unwrap();
539 let message = completion.messages.last().unwrap();
540 assert_eq!(
541 message.content,
542 vec![language_model::MessageContent::ToolResult(
543 LanguageModelToolResult {
544 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
545 tool_name: ToolRequiringPermission::name().into(),
546 is_error: false,
547 content: "Allowed".into(),
548 output: Some("Allowed".into())
549 }
550 )]
551 );
552
553 // Simulate a final tool call, ensuring we don't trigger authorization.
554 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
555 LanguageModelToolUse {
556 id: "tool_id_4".into(),
557 name: ToolRequiringPermission::name().into(),
558 raw_input: "{}".into(),
559 input: json!({}),
560 is_input_complete: true,
561 thought_signature: None,
562 },
563 ));
564 fake_model.end_last_completion_stream();
565 cx.run_until_parked();
566 let completion = fake_model.pending_completions().pop().unwrap();
567 let message = completion.messages.last().unwrap();
568 assert_eq!(
569 message.content,
570 vec![language_model::MessageContent::ToolResult(
571 LanguageModelToolResult {
572 tool_use_id: "tool_id_4".into(),
573 tool_name: ToolRequiringPermission::name().into(),
574 is_error: false,
575 content: "Allowed".into(),
576 output: Some("Allowed".into())
577 }
578 )]
579 );
580}
581
582#[gpui::test]
583async fn test_tool_hallucination(cx: &mut TestAppContext) {
584 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
585 let fake_model = model.as_fake();
586
587 let mut events = thread
588 .update(cx, |thread, cx| {
589 thread.send(UserMessageId::new(), ["abc"], cx)
590 })
591 .unwrap();
592 cx.run_until_parked();
593 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
594 LanguageModelToolUse {
595 id: "tool_id_1".into(),
596 name: "nonexistent_tool".into(),
597 raw_input: "{}".into(),
598 input: json!({}),
599 is_input_complete: true,
600 thought_signature: None,
601 },
602 ));
603 fake_model.end_last_completion_stream();
604
605 let tool_call = expect_tool_call(&mut events).await;
606 assert_eq!(tool_call.title, "nonexistent_tool");
607 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
608 let update = expect_tool_call_update_fields(&mut events).await;
609 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
610}
611
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports the tool-use limit, the turn ends with a
    // ToolUseLimitReachedError; `resume` should continue the conversation by
    // appending a "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool runs and its result is sent back; the trailing message is the
    // cached one (messages[0] is the system prompt, skipped via [1..]).
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a canned continuation prompt as the newest (cached)
    // user message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
719
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a fresh user message (rather
    // than resuming) should still work: history is preserved and the new
    // message becomes the cached tail of the request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use immediately followed by the limit event; the stream
    // ends with a ToolUseLimitReachedError.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt, skipped via [1..]; the new "ghi"
    // message is the only one marked cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
792
793async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
794 let event = events
795 .next()
796 .await
797 .expect("no tool call authorization event received")
798 .unwrap();
799 match event {
800 ThreadEvent::ToolCall(tool_call) => tool_call,
801 event => {
802 panic!("Unexpected event {event:?}");
803 }
804 }
805}
806
807async fn expect_tool_call_update_fields(
808 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
809) -> acp::ToolCallUpdate {
810 let event = events
811 .next()
812 .await
813 .expect("no tool call authorization event received")
814 .unwrap();
815 match event {
816 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
817 event => {
818 panic!("Unexpected event {event:?}");
819 }
820 }
821}
822
823async fn next_tool_call_authorization(
824 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
825) -> ToolCallAuthorization {
826 loop {
827 let event = events
828 .next()
829 .await
830 .expect("no tool call authorization event received")
831 .unwrap();
832 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
833 let permission_kinds = tool_call_authorization
834 .options
835 .iter()
836 .map(|o| o.kind)
837 .collect::<Vec<_>>();
838 assert_eq!(
839 permission_kinds,
840 vec![
841 acp::PermissionOptionKind::AllowAlways,
842 acp::PermissionOptionKind::AllowOnce,
843 acp::PermissionOptionKind::RejectOnce,
844 ]
845 );
846 return tool_call_authorization;
847 }
848 }
849}
850
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the `e2e`
    // feature): two overlapping delay-tool calls in one message should both
    // complete, and their output should appear in the final response.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final assistant message and
        // check it mentions the delay tool's output ("Ding").
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
895
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are offered to the model in the completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
975
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies MCP (context-server) tools: they are exposed to the model,
    // calls round-trip through the server, and a name collision with a
    // native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a fake context server exposing an "echo" tool whose schema
    // mirrors the native EchoTool's.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call should be forwarded to the context server; answer it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts the `completion` captured before the
    // tool call rather than popping a fresh one — presumably intended to
    // confirm the offered tool names are unchanged; verify.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is renamed to "test_server_echo"; the native tool keeps
    // the bare "echo" name.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the prefixed call goes to the server; note the server sees the
    // original, unprefixed tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1141
/// Verifies tool-name collision handling and length-limiting when native
/// tools and multiple MCP context servers are combined into one request:
/// colliding MCP tool names get a server-derived prefix, and names at or over
/// `MAX_TOOL_NAME_LENGTH` are shortened (see the expected list at the end).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all native tools and all context servers enabled.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names.
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names, duplicated on the next server to force
            // disambiguation via a shortened server prefix.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Over the limit even before any prefixing.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected (sorted) tool names in the request, showing:
    // - native tools unchanged ("delay", "echo", ...);
    // - colliding MCP "echo" tools prefixed with their server id
    //   ("xxx_echo", "yyy_echo");
    // - duplicated near-limit names disambiguated with a shortened server
    //   prefix ("y_aaa…", "z_aaa…") and the over-limit name cut down ("ccc…").
    //   NOTE(review): the exact shortening scheme is inferred from these
    //   expected values — confirm against the truncation implementation.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1307
/// E2E test (requires the `e2e` feature and a real model): canceling a turn
/// while a tool call is still running must close the event stream with
/// `StopReason::Cancelled`, and the thread must accept new messages after.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools should be invoked in the order we asked for.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool completing; the infinite tool never will.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1393
1394#[gpui::test]
1395async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1396 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1397 let fake_model = model.as_fake();
1398
1399 let events_1 = thread
1400 .update(cx, |thread, cx| {
1401 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1402 })
1403 .unwrap();
1404 cx.run_until_parked();
1405 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1406 cx.run_until_parked();
1407
1408 let events_2 = thread
1409 .update(cx, |thread, cx| {
1410 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1411 })
1412 .unwrap();
1413 cx.run_until_parked();
1414 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1415 fake_model
1416 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1417 fake_model.end_last_completion_stream();
1418
1419 let events_1 = events_1.collect::<Vec<_>>().await;
1420 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1421 let events_2 = events_2.collect::<Vec<_>>().await;
1422 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1423}
1424
1425#[gpui::test]
1426async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1427 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1428 let fake_model = model.as_fake();
1429
1430 let events_1 = thread
1431 .update(cx, |thread, cx| {
1432 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1433 })
1434 .unwrap();
1435 cx.run_until_parked();
1436 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1437 fake_model
1438 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1439 fake_model.end_last_completion_stream();
1440 let events_1 = events_1.collect::<Vec<_>>().await;
1441
1442 let events_2 = thread
1443 .update(cx, |thread, cx| {
1444 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1445 })
1446 .unwrap();
1447 cx.run_until_parked();
1448 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1449 fake_model
1450 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1451 fake_model.end_last_completion_stream();
1452 let events_2 = events_2.collect::<Vec<_>>().await;
1453
1454 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1455 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1456}
1457
1458#[gpui::test]
1459async fn test_refusal(cx: &mut TestAppContext) {
1460 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1461 let fake_model = model.as_fake();
1462
1463 let events = thread
1464 .update(cx, |thread, cx| {
1465 thread.send(UserMessageId::new(), ["Hello"], cx)
1466 })
1467 .unwrap();
1468 cx.run_until_parked();
1469 thread.read_with(cx, |thread, _| {
1470 assert_eq!(
1471 thread.to_markdown(),
1472 indoc! {"
1473 ## User
1474
1475 Hello
1476 "}
1477 );
1478 });
1479
1480 fake_model.send_last_completion_stream_text_chunk("Hey!");
1481 cx.run_until_parked();
1482 thread.read_with(cx, |thread, _| {
1483 assert_eq!(
1484 thread.to_markdown(),
1485 indoc! {"
1486 ## User
1487
1488 Hello
1489
1490 ## Assistant
1491
1492 Hey!
1493 "}
1494 );
1495 });
1496
1497 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1498 fake_model
1499 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1500 let events = events.collect::<Vec<_>>().await;
1501 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1502 thread.read_with(cx, |thread, _| {
1503 assert_eq!(thread.to_markdown(), "");
1504 });
1505}
1506
/// Truncating at the very first message should clear the entire thread and
/// reset token usage; the thread must remain usable for new messages after.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported by the model yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            // max_tokens presumably reflects the fake model's context window —
            // TODO(review): confirm against the fake provider.
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: content and usage reset entirely.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1622
/// Truncating at the second user message should restore the thread to its
/// state after the first exchange, including the earlier token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The expected post-first-exchange state, asserted both before the second
    // message and again after truncating it away.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and verify we are back to the first state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1731
1732#[gpui::test]
1733async fn test_title_generation(cx: &mut TestAppContext) {
1734 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1735 let fake_model = model.as_fake();
1736
1737 let summary_model = Arc::new(FakeLanguageModel::default());
1738 thread.update(cx, |thread, cx| {
1739 thread.set_summarization_model(Some(summary_model.clone()), cx)
1740 });
1741
1742 let send = thread
1743 .update(cx, |thread, cx| {
1744 thread.send(UserMessageId::new(), ["Hello"], cx)
1745 })
1746 .unwrap();
1747 cx.run_until_parked();
1748
1749 fake_model.send_last_completion_stream_text_chunk("Hey!");
1750 fake_model.end_last_completion_stream();
1751 cx.run_until_parked();
1752 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1753
1754 // Ensure the summary model has been invoked to generate a title.
1755 summary_model.send_last_completion_stream_text_chunk("Hello ");
1756 summary_model.send_last_completion_stream_text_chunk("world\nG");
1757 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1758 summary_model.end_last_completion_stream();
1759 send.collect::<Vec<_>>().await;
1760 cx.run_until_parked();
1761 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1762
1763 // Send another message, ensuring no title is generated this time.
1764 let send = thread
1765 .update(cx, |thread, cx| {
1766 thread.send(UserMessageId::new(), ["Hello again"], cx)
1767 })
1768 .unwrap();
1769 cx.run_until_parked();
1770 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1771 fake_model.end_last_completion_stream();
1772 cx.run_until_parked();
1773 assert_eq!(summary_model.pending_completions(), Vec::new());
1774 send.collect::<Vec<_>>().await;
1775 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1776}
1777
/// Ensures that when building a completion request, a tool use whose result
/// is still pending is omitted, while a completed tool use and its result
/// are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Presumably stays pending because the permission prompt is never answered
    // in this test — TODO(review): confirm ToolRequiringPermission's behavior.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Runs to completion, so it should appear in the built request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request: only the
    // completed echo tool use ("tool_id_2") appears below, never "tool_id_1".
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1854
/// End-to-end exercise of `NativeAgentConnection` over the ACP surface:
/// model listing and selection, prompting, cancellation, and verifying that
/// dropping the `AcpThread` tears down the underlying session.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize the language model system with the test provider. The 404
    // HTTP client and fake clock keep everything offline and deterministic.
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project (on a fake filesystem) for new_thread.
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: expect a grouped list containing the fake model.
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream part of a response from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel: the in-flight prompt resolves without error.
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well, so prompting the old session fails.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1989
/// Streams a tool call whose input arrives in two chunks (partial, then
/// complete) and verifies the resulting ACP event sequence: an initial
/// Pending tool call, an update with the finalized input, an InProgress
/// status, streamed content, and a final Completed status with raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input (empty object, input not complete yet).
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed: same tool-use id, now with the full input.
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The partial tool use surfaces as a Pending tool call with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 1) Finalized input replaces the partial raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 2) The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4) Completion carries the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2098
2099#[gpui::test]
2100async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2101 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2102 let fake_model = model.as_fake();
2103
2104 let mut events = thread
2105 .update(cx, |thread, cx| {
2106 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2107 thread.send(UserMessageId::new(), ["Hello!"], cx)
2108 })
2109 .unwrap();
2110 cx.run_until_parked();
2111
2112 fake_model.send_last_completion_stream_text_chunk("Hey!");
2113 fake_model.end_last_completion_stream();
2114
2115 let mut retry_events = Vec::new();
2116 while let Some(Ok(event)) = events.next().await {
2117 match event {
2118 ThreadEvent::Retry(retry_status) => {
2119 retry_events.push(retry_status);
2120 }
2121 ThreadEvent::Stop(..) => break,
2122 _ => {}
2123 }
2124 }
2125
2126 assert_eq!(retry_events.len(), 0);
2127 thread.read_with(cx, |thread, _cx| {
2128 assert_eq!(
2129 thread.to_markdown(),
2130 indoc! {"
2131 ## User
2132
2133 Hello!
2134
2135 ## Assistant
2136
2137 Hey!
2138 "}
2139 )
2140 });
2141}
2142
2143#[gpui::test]
2144async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2145 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2146 let fake_model = model.as_fake();
2147
2148 let mut events = thread
2149 .update(cx, |thread, cx| {
2150 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2151 thread.send(UserMessageId::new(), ["Hello!"], cx)
2152 })
2153 .unwrap();
2154 cx.run_until_parked();
2155
2156 fake_model.send_last_completion_stream_text_chunk("Hey,");
2157 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2158 provider: LanguageModelProviderName::new("Anthropic"),
2159 retry_after: Some(Duration::from_secs(3)),
2160 });
2161 fake_model.end_last_completion_stream();
2162
2163 cx.executor().advance_clock(Duration::from_secs(3));
2164 cx.run_until_parked();
2165
2166 fake_model.send_last_completion_stream_text_chunk("there!");
2167 fake_model.end_last_completion_stream();
2168 cx.run_until_parked();
2169
2170 let mut retry_events = Vec::new();
2171 while let Some(Ok(event)) = events.next().await {
2172 match event {
2173 ThreadEvent::Retry(retry_status) => {
2174 retry_events.push(retry_status);
2175 }
2176 ThreadEvent::Stop(..) => break,
2177 _ => {}
2178 }
2179 }
2180
2181 assert_eq!(retry_events.len(), 1);
2182 assert!(matches!(
2183 retry_events[0],
2184 acp_thread::RetryStatus { attempt: 1, .. }
2185 ));
2186 thread.read_with(cx, |thread, _cx| {
2187 assert_eq!(
2188 thread.to_markdown(),
2189 indoc! {"
2190 ## User
2191
2192 Hello!
2193
2194 ## Assistant
2195
2196 Hey,
2197
2198 [resume]
2199
2200 ## Assistant
2201
2202 there!
2203 "}
2204 )
2205 });
2206}
2207
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If the stream errors after a complete tool use was received, the tool
    // should still run, and its result must appear in the retried request.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // A fully-formed tool use (is_input_complete: true) streamed before the error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    // Fail the stream with a retryable overload error right after the tool use.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // The retried request must carry the original user message, the
    // assistant's tool use, and the echo tool's result. messages[0] is
    // skipped — presumably the system prompt; confirm against how the thread
    // builds completion requests.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and verify the final assistant message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2284
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, the next
    // failure should surface as an error on the event stream instead of
    // another retry.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt: the initial one plus all MAX_RETRY_ATTEMPTS retries.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        // Advance past retry_after so the next attempt (if any) is started.
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Partition the stream into retry notifications and terminal errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Every allowed retry was attempted, numbered 1..=MAX_RETRY_ATTEMPTS...
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // ...and the final failure was reported as a single ServerOverloaded error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2339
2340/// Filters out the stop events for asserting against in tests
2341fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2342 result_events
2343 .into_iter()
2344 .filter_map(|event| match event.unwrap() {
2345 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2346 _ => None,
2347 })
2348 .collect()
2349}
2350
/// Fixture returned by `setup`: the model under test, the thread, and handles
/// the tests use to mutate project context, context servers, and the fake fs.
struct ThreadTest {
    // Either a FakeLanguageModel or a real registry model (see TestModel).
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2358
/// Selects which language model a test runs against.
enum TestModel {
    // Real Anthropic model resolved from the registry; requires network and
    // provider authentication (see `setup`).
    Sonnet4,
    // In-process FakeLanguageModel driven manually by the test.
    Fake,
}

impl TestModel {
    // Registry id for real models. Must not be called for `Fake`, which is
    // constructed directly rather than looked up by id.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
2372
/// Builds the common test fixture: a fake filesystem with a settings file that
/// enables all test tools, a test project, the requested model, and a `Thread`
/// wired to them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work (HTTP, auth), so allow the test
    // executor to park — TODO confirm this is also needed for the fake path.
    cx.executor().allow_parking();

    // Write a settings file that selects "test-profile" and enables every
    // test tool, so tool filtering doesn't hide them from the thread.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no registry or network setup.
            TestModel::Fake => {}
            // Real models need an HTTP client, a production client/user store,
            // and the language-model providers registered.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later edits to the fake
        // settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Look the model up in the registry and authenticate its
                // provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2474
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Opt-in logging: only initialize env_logger when RUST_LOG is set.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2482
2483fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2484 let fs = fs.clone();
2485 cx.spawn({
2486 async move |cx| {
2487 let mut new_settings_content_rx = settings::watch_config_file(
2488 cx.background_executor(),
2489 fs,
2490 paths::settings_file().clone(),
2491 );
2492
2493 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2494 cx.update(|cx| {
2495 SettingsStore::update_global(cx, |settings, cx| {
2496 settings.set_user_settings(&new_settings_content, cx)
2497 })
2498 })
2499 .ok();
2500 }
2501 }
2502 })
2503 .detach();
2504}
2505
2506fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2507 completion
2508 .tools
2509 .iter()
2510 .map(|tool| tool.name.clone())
2511 .collect()
2512}
2513
/// Registers a fake MCP context server named `name` exposing `tools`, starts
/// it in `context_server_store`, and returns a receiver of incoming tool
/// calls. Each received item pairs the call's params with a oneshot sender the
/// test must use to reply.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store considers it
    // enabled; the command is a dummy since a fake transport is used below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport handling the three requests the thread issues:
    // Initialize, ListTools, and CallTool.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    // Advertise tool support so the thread fetches the tool list.
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block this request until
                // the test replies through the oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the store finish the Initialize/ListTools handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}