1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a `Thinking` completion event is rendered inside
    // <think></think> tags in the assistant message's markdown, followed by
    // the regular response text.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step, then the visible answer, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the system prompt reflects the project context (the shell)
    // and includes tool-related guidance when at least one tool is registered.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so we can detect it in the prompt below.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // Registering a tool should cause tool sections to appear in the prompt.
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt must be the first message of the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: only the most recent message of a
    // request is marked `cache: true`; all earlier messages are uncached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so compare from index 1 onwards.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, its result is the newest message and thus the only
    // cached one; every earlier message has been demoted to `cache: false`.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
332
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the "e2e" feature):
    // exercises tool calls resolving both before and after streaming ends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool outputs "Ding"; the model was asked to echo it, so it
        // should appear in the final assistant message.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
392
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real model): tool-call input should be observable in
    // the thread while it is still streaming, i.e. before it is complete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a turn that triggers the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming; if the
                        // final field ("g") hasn't arrived yet we've observed
                        // a genuinely partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the input must be fully streamed.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
444
445#[gpui::test]
446async fn test_tool_authorization(cx: &mut TestAppContext) {
447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
448 let fake_model = model.as_fake();
449
450 let mut events = thread
451 .update(cx, |thread, cx| {
452 thread.add_tool(ToolRequiringPermission);
453 thread.send(UserMessageId::new(), ["abc"], cx)
454 })
455 .unwrap();
456 cx.run_until_parked();
457 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
458 LanguageModelToolUse {
459 id: "tool_id_1".into(),
460 name: ToolRequiringPermission::name().into(),
461 raw_input: "{}".into(),
462 input: json!({}),
463 is_input_complete: true,
464 },
465 ));
466 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
467 LanguageModelToolUse {
468 id: "tool_id_2".into(),
469 name: ToolRequiringPermission::name().into(),
470 raw_input: "{}".into(),
471 input: json!({}),
472 is_input_complete: true,
473 },
474 ));
475 fake_model.end_last_completion_stream();
476 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
477 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
478
479 // Approve the first
480 tool_call_auth_1
481 .response
482 .send(tool_call_auth_1.options[1].id.clone())
483 .unwrap();
484 cx.run_until_parked();
485
486 // Reject the second
487 tool_call_auth_2
488 .response
489 .send(tool_call_auth_1.options[2].id.clone())
490 .unwrap();
491 cx.run_until_parked();
492
493 let completion = fake_model.pending_completions().pop().unwrap();
494 let message = completion.messages.last().unwrap();
495 assert_eq!(
496 message.content,
497 vec![
498 language_model::MessageContent::ToolResult(LanguageModelToolResult {
499 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
500 tool_name: ToolRequiringPermission::name().into(),
501 is_error: false,
502 content: "Allowed".into(),
503 output: Some("Allowed".into())
504 }),
505 language_model::MessageContent::ToolResult(LanguageModelToolResult {
506 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
507 tool_name: ToolRequiringPermission::name().into(),
508 is_error: true,
509 content: "Permission to run tool denied by user".into(),
510 output: Some("Permission to run tool denied by user".into())
511 })
512 ]
513 );
514
515 // Simulate yet another tool call.
516 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
517 LanguageModelToolUse {
518 id: "tool_id_3".into(),
519 name: ToolRequiringPermission::name().into(),
520 raw_input: "{}".into(),
521 input: json!({}),
522 is_input_complete: true,
523 },
524 ));
525 fake_model.end_last_completion_stream();
526
527 // Respond by always allowing tools.
528 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
529 tool_call_auth_3
530 .response
531 .send(tool_call_auth_3.options[0].id.clone())
532 .unwrap();
533 cx.run_until_parked();
534 let completion = fake_model.pending_completions().pop().unwrap();
535 let message = completion.messages.last().unwrap();
536 assert_eq!(
537 message.content,
538 vec![language_model::MessageContent::ToolResult(
539 LanguageModelToolResult {
540 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
541 tool_name: ToolRequiringPermission::name().into(),
542 is_error: false,
543 content: "Allowed".into(),
544 output: Some("Allowed".into())
545 }
546 )]
547 );
548
549 // Simulate a final tool call, ensuring we don't trigger authorization.
550 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
551 LanguageModelToolUse {
552 id: "tool_id_4".into(),
553 name: ToolRequiringPermission::name().into(),
554 raw_input: "{}".into(),
555 input: json!({}),
556 is_input_complete: true,
557 },
558 ));
559 fake_model.end_last_completion_stream();
560 cx.run_until_parked();
561 let completion = fake_model.pending_completions().pop().unwrap();
562 let message = completion.messages.last().unwrap();
563 assert_eq!(
564 message.content,
565 vec![language_model::MessageContent::ToolResult(
566 LanguageModelToolResult {
567 tool_use_id: "tool_id_4".into(),
568 tool_name: ToolRequiringPermission::name().into(),
569 is_error: false,
570 content: "Allowed".into(),
571 output: Some("Allowed".into())
572 }
573 )]
574 );
575}
576
577#[gpui::test]
578async fn test_tool_hallucination(cx: &mut TestAppContext) {
579 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
580 let fake_model = model.as_fake();
581
582 let mut events = thread
583 .update(cx, |thread, cx| {
584 thread.send(UserMessageId::new(), ["abc"], cx)
585 })
586 .unwrap();
587 cx.run_until_parked();
588 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
589 LanguageModelToolUse {
590 id: "tool_id_1".into(),
591 name: "nonexistent_tool".into(),
592 raw_input: "{}".into(),
593 input: json!({}),
594 is_input_complete: true,
595 },
596 ));
597 fake_model.end_last_completion_stream();
598
599 let tool_call = expect_tool_call(&mut events).await;
600 assert_eq!(tool_call.title, "nonexistent_tool");
601 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
602 let update = expect_tool_call_update_fields(&mut events).await;
603 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
604}
605
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the server reports a tool-use limit, `resume` should continue the
    // turn by appending a "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran; its result is the newest message, so it is the only one
    // marked as cached. messages[0] is the system prompt, hence `[1..]`.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn should terminate with a `ToolUseLimitReachedError`.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic "continue" user message, which becomes the
    // newest (and therefore only cached) message of the next request.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
714
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the tool-use limit is reported, sending a brand-new user message
    // should still work: the earlier tool use/result stay in the history and
    // the new message becomes the latest (cached) entry.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use immediately followed by the limit notification.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn should terminate with a `ToolUseLimitReachedError`.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new "ghi" message is the only
    // cached one.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
788
789async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
790 let event = events
791 .next()
792 .await
793 .expect("no tool call authorization event received")
794 .unwrap();
795 match event {
796 ThreadEvent::ToolCall(tool_call) => tool_call,
797 event => {
798 panic!("Unexpected event {event:?}");
799 }
800 }
801}
802
803async fn expect_tool_call_update_fields(
804 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
805) -> acp::ToolCallUpdate {
806 let event = events
807 .next()
808 .await
809 .expect("no tool call authorization event received")
810 .unwrap();
811 match event {
812 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
813 event => {
814 panic!("Unexpected event {event:?}");
815 }
816 }
817}
818
819async fn next_tool_call_authorization(
820 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
821) -> ToolCallAuthorization {
822 loop {
823 let event = events
824 .next()
825 .await
826 .expect("no tool call authorization event received")
827 .unwrap();
828 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
829 let permission_kinds = tool_call_authorization
830 .options
831 .iter()
832 .map(|o| o.kind)
833 .collect::<Vec<_>>();
834 assert_eq!(
835 permission_kinds,
836 vec![
837 acp::PermissionOptionKind::AllowAlways,
838 acp::PermissionOptionKind::AllowOnce,
839 acp::PermissionOptionKind::RejectOnce,
840 ]
841 );
842 return tool_call_authorization;
843 }
844 }
845}
846
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the "e2e"
    // feature): two delay-tool calls issued in one turn should both complete
    // and have their outputs described by the model.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final assistant message and
        // check that the delay tool's "Ding" output was echoed back.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
891
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are offered to the model on the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; the active profile decides which are sent.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
971
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies MCP (context server) tool integration: server tools are
    // surfaced to the model, calls are routed to the server, and a name
    // collision with a native tool is resolved by prefixing the MCP tool
    // with its server name ("test_server_echo").
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Start a fake context server exposing an "echo" tool whose input schema
    // matches the native EchoTool's.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call should be forwarded to the context server with its arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` popped earlier, not a
    // freshly popped pending completion — presumably the intent was to check
    // the follow-up request's tools; confirm.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Call both the renamed MCP tool and the native tool in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server still sees its original (unprefixed) tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1134
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP (context-server) tool names are disambiguated when they
    // collide with native tools or with each other, and how names at or over
    // MAX_TOOL_NAME_LENGTH are truncated.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Second server: another "echo" conflict plus names just under the length
    // limit, which also collide with server "zzz" below.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Third server: duplicates yyy's near-limit names and adds one name that
    // exceeds the limit outright.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Expected naming, as asserted below:
    // - Unique MCP names are kept as-is ("unique_tool_1", "unique_tool_2").
    // - MCP tools colliding with a native tool get a server-name prefix
    //   ("xxx_echo", "yyy_echo"); the native "echo" keeps its bare name.
    // - Near-limit names colliding across servers get a truncated one-letter
    //   server prefix ("y_aaa…", "z_aaa…"); the over-limit "c…" name is cut to
    //   the maximum length. NOTE(review): only one of the duplicate 63-char
    //   "b…" names survives — apparently it cannot be prefixed without
    //   exceeding MAX_TOOL_NAME_LENGTH, so the other is dropped.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1300
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (runs only with the "e2e" feature) against a real
    // Sonnet 4 model: cancelling mid-turn closes the event stream even while a
    // tool is still running, and the thread remains usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically, so we cancel
            // while only the infinite tool is still running.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1386
1387#[gpui::test]
1388async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1389 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1390 let fake_model = model.as_fake();
1391
1392 let events_1 = thread
1393 .update(cx, |thread, cx| {
1394 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1395 })
1396 .unwrap();
1397 cx.run_until_parked();
1398 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1399 cx.run_until_parked();
1400
1401 let events_2 = thread
1402 .update(cx, |thread, cx| {
1403 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1404 })
1405 .unwrap();
1406 cx.run_until_parked();
1407 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1408 fake_model
1409 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1410 fake_model.end_last_completion_stream();
1411
1412 let events_1 = events_1.collect::<Vec<_>>().await;
1413 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1414 let events_2 = events_2.collect::<Vec<_>>().await;
1415 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1416}
1417
1418#[gpui::test]
1419async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1420 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1421 let fake_model = model.as_fake();
1422
1423 let events_1 = thread
1424 .update(cx, |thread, cx| {
1425 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1426 })
1427 .unwrap();
1428 cx.run_until_parked();
1429 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1430 fake_model
1431 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1432 fake_model.end_last_completion_stream();
1433 let events_1 = events_1.collect::<Vec<_>>().await;
1434
1435 let events_2 = thread
1436 .update(cx, |thread, cx| {
1437 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1438 })
1439 .unwrap();
1440 cx.run_until_parked();
1441 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1442 fake_model
1443 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1444 fake_model.end_last_completion_stream();
1445 let events_2 = events_2.collect::<Vec<_>>().await;
1446
1447 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1448 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_refusal(cx: &mut TestAppContext) {
1453 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454 let fake_model = model.as_fake();
1455
1456 let events = thread
1457 .update(cx, |thread, cx| {
1458 thread.send(UserMessageId::new(), ["Hello"], cx)
1459 })
1460 .unwrap();
1461 cx.run_until_parked();
1462 thread.read_with(cx, |thread, _| {
1463 assert_eq!(
1464 thread.to_markdown(),
1465 indoc! {"
1466 ## User
1467
1468 Hello
1469 "}
1470 );
1471 });
1472
1473 fake_model.send_last_completion_stream_text_chunk("Hey!");
1474 cx.run_until_parked();
1475 thread.read_with(cx, |thread, _| {
1476 assert_eq!(
1477 thread.to_markdown(),
1478 indoc! {"
1479 ## User
1480
1481 Hello
1482
1483 ## Assistant
1484
1485 Hey!
1486 "}
1487 );
1488 });
1489
1490 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1491 fake_model
1492 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1493 let events = events.collect::<Vec<_>>().await;
1494 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1495 thread.read_with(cx, |thread, _| {
1496 assert_eq!(thread.to_markdown(), "");
1497 });
1498}
1499
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message clears the entire transcript and
    // resets token usage; the thread must remain usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update (48k of the model's reported 1M max).
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage both reset.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is visible immediately, before the model responds.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects the new turn only, not the truncated one.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1615
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message restores the thread to exactly the
    // state it had after the first exchange, including token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: one user message, one assistant reply, one usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable check for the post-first-exchange state, asserted both before
    // the second message and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose id we keep so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: both transcript and usage roll back.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1724
1725#[gpui::test]
1726async fn test_title_generation(cx: &mut TestAppContext) {
1727 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1728 let fake_model = model.as_fake();
1729
1730 let summary_model = Arc::new(FakeLanguageModel::default());
1731 thread.update(cx, |thread, cx| {
1732 thread.set_summarization_model(Some(summary_model.clone()), cx)
1733 });
1734
1735 let send = thread
1736 .update(cx, |thread, cx| {
1737 thread.send(UserMessageId::new(), ["Hello"], cx)
1738 })
1739 .unwrap();
1740 cx.run_until_parked();
1741
1742 fake_model.send_last_completion_stream_text_chunk("Hey!");
1743 fake_model.end_last_completion_stream();
1744 cx.run_until_parked();
1745 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1746
1747 // Ensure the summary model has been invoked to generate a title.
1748 summary_model.send_last_completion_stream_text_chunk("Hello ");
1749 summary_model.send_last_completion_stream_text_chunk("world\nG");
1750 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1751 summary_model.end_last_completion_stream();
1752 send.collect::<Vec<_>>().await;
1753 cx.run_until_parked();
1754 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1755
1756 // Send another message, ensuring no title is generated this time.
1757 let send = thread
1758 .update(cx, |thread, cx| {
1759 thread.send(UserMessageId::new(), ["Hello again"], cx)
1760 })
1761 .unwrap();
1762 cx.run_until_parked();
1763 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1764 fake_model.end_last_completion_stream();
1765 cx.run_until_parked();
1766 assert_eq!(summary_model.pending_completions(), Vec::new());
1767 send.collect::<Vec<_>>().await;
1768 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1769}
1770
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Tool calls that are still pending (here: awaiting user permission) must
    // be omitted when building a follow-up completion request, while completed
    // tool calls are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // The `[1..]` slice skips messages[0] — presumably the system prompt,
    // which isn't under test here (TODO confirm).
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            // Only the completed echo tool use appears; the pending
            // permission tool use is absent.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1845
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Exercises the full ACP connection surface: thread creation, model
    // listing/selection, prompting, cancellation, and session teardown.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a reply from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with a session-not-found error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1980
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the exact sequence of ACP events emitted for one tool call:
    // ToolCall (pending, from partial input), then field updates for the
    // completed input, InProgress, streamed content, and Completed + output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The partial tool use surfaces as a pending ToolCall with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. Completed input arrives as a field update carrying the raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. The tool finishes, reporting its final raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2087
2088#[gpui::test]
2089async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2090 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2091 let fake_model = model.as_fake();
2092
2093 let mut events = thread
2094 .update(cx, |thread, cx| {
2095 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2096 thread.send(UserMessageId::new(), ["Hello!"], cx)
2097 })
2098 .unwrap();
2099 cx.run_until_parked();
2100
2101 fake_model.send_last_completion_stream_text_chunk("Hey!");
2102 fake_model.end_last_completion_stream();
2103
2104 let mut retry_events = Vec::new();
2105 while let Some(Ok(event)) = events.next().await {
2106 match event {
2107 ThreadEvent::Retry(retry_status) => {
2108 retry_events.push(retry_status);
2109 }
2110 ThreadEvent::Stop(..) => break,
2111 _ => {}
2112 }
2113 }
2114
2115 assert_eq!(retry_events.len(), 0);
2116 thread.read_with(cx, |thread, _cx| {
2117 assert_eq!(
2118 thread.to_markdown(),
2119 indoc! {"
2120 ## User
2121
2122 Hello!
2123
2124 ## Assistant
2125
2126 Hey!
2127 "}
2128 )
2129 });
2130}
2131
2132#[gpui::test]
2133async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2134 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2135 let fake_model = model.as_fake();
2136
2137 let mut events = thread
2138 .update(cx, |thread, cx| {
2139 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2140 thread.send(UserMessageId::new(), ["Hello!"], cx)
2141 })
2142 .unwrap();
2143 cx.run_until_parked();
2144
2145 fake_model.send_last_completion_stream_text_chunk("Hey,");
2146 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2147 provider: LanguageModelProviderName::new("Anthropic"),
2148 retry_after: Some(Duration::from_secs(3)),
2149 });
2150 fake_model.end_last_completion_stream();
2151
2152 cx.executor().advance_clock(Duration::from_secs(3));
2153 cx.run_until_parked();
2154
2155 fake_model.send_last_completion_stream_text_chunk("there!");
2156 fake_model.end_last_completion_stream();
2157 cx.run_until_parked();
2158
2159 let mut retry_events = Vec::new();
2160 while let Some(Ok(event)) = events.next().await {
2161 match event {
2162 ThreadEvent::Retry(retry_status) => {
2163 retry_events.push(retry_status);
2164 }
2165 ThreadEvent::Stop(..) => break,
2166 _ => {}
2167 }
2168 }
2169
2170 assert_eq!(retry_events.len(), 1);
2171 assert!(matches!(
2172 retry_events[0],
2173 acp_thread::RetryStatus { attempt: 1, .. }
2174 ));
2175 thread.read_with(cx, |thread, _cx| {
2176 assert_eq!(
2177 thread.to_markdown(),
2178 indoc! {"
2179 ## User
2180
2181 Hello!
2182
2183 ## Assistant
2184
2185 Hey,
2186
2187 [resume]
2188
2189 ## Assistant
2190
2191 there!
2192 "}
2193 )
2194 });
2195}
2196
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a retryable error arrives after a tool use was streamed, the tool
    // still runs, and the retried request includes both the tool use and its
    // result so the model can continue from a consistent state.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a complete tool use, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the re-issued request must already contain the
    // tool use and the tool's result (the tool completed despite the error).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            // The tool result message carries the cache flag, being last.
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // The retried completion finishes the turn normally.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2272
2273#[gpui::test]
2274async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2275 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2276 let fake_model = model.as_fake();
2277
2278 let mut events = thread
2279 .update(cx, |thread, cx| {
2280 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2281 thread.send(UserMessageId::new(), ["Hello!"], cx)
2282 })
2283 .unwrap();
2284 cx.run_until_parked();
2285
2286 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2287 fake_model.send_last_completion_stream_error(
2288 LanguageModelCompletionError::ServerOverloaded {
2289 provider: LanguageModelProviderName::new("Anthropic"),
2290 retry_after: Some(Duration::from_secs(3)),
2291 },
2292 );
2293 fake_model.end_last_completion_stream();
2294 cx.executor().advance_clock(Duration::from_secs(3));
2295 cx.run_until_parked();
2296 }
2297
2298 let mut errors = Vec::new();
2299 let mut retry_events = Vec::new();
2300 while let Some(event) = events.next().await {
2301 match event {
2302 Ok(ThreadEvent::Retry(retry_status)) => {
2303 retry_events.push(retry_status);
2304 }
2305 Ok(ThreadEvent::Stop(..)) => break,
2306 Err(error) => errors.push(error),
2307 _ => {}
2308 }
2309 }
2310
2311 assert_eq!(
2312 retry_events.len(),
2313 crate::thread::MAX_RETRY_ATTEMPTS as usize
2314 );
2315 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2316 assert_eq!(retry_events[i].attempt, i + 1);
2317 }
2318 assert_eq!(errors.len(), 1);
2319 let error = errors[0]
2320 .downcast_ref::<LanguageModelCompletionError>()
2321 .unwrap();
2322 assert!(matches!(
2323 error,
2324 LanguageModelCompletionError::ServerOverloaded { .. }
2325 ));
2326}
2327
2328/// Filters out the stop events for asserting against in tests
2329fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2330 result_events
2331 .into_iter()
2332 .filter_map(|event| match event.unwrap() {
2333 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2334 _ => None,
2335 })
2336 .collect()
2337}
2338
/// Fixture returned by [`setup`] bundling the entities a test needs.
struct ThreadTest {
    // Model the thread was created with; `as_fake()` is used by tests driving
    // the `TestModel::Fake` variant.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context that was passed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    // The project's context-server store (used by `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing both the settings file and the test project.
    fs: Arc<FakeFs>,
}
2346
/// Which language model a test runs against.
enum TestModel {
    // A real Anthropic model resolved from the registry; `setup` authenticates
    // its provider before use.
    Sonnet4,
    // An in-process `FakeLanguageModel` whose streams tests drive manually.
    Fake,
}
2351
2352impl TestModel {
2353 fn id(&self) -> LanguageModelId {
2354 match self {
2355 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2356 TestModel::Fake => unreachable!(),
2357 }
2358 }
2359}
2360
/// Builds the shared test fixture: a settings file that enables all test
/// tools, a test project rooted at /test, the requested model, and a fresh
/// `Thread` wired to them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed a settings file whose "test-profile" enables every test tool, so
    // the agent's default profile permits the tools tests register.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // The fake model needs no client setup; Sonnet4 talks to the real
        // service, so it needs an HTTP client, a production `Client`, and the
        // language-model providers registered.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global `SettingsStore` in sync with edits tests make to
        // the settings file on the fake filesystem.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is constructed directly; a real model is
    // found in the registry and its provider authenticated before returning.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2462
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up env_logger when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2470
2471fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2472 let fs = fs.clone();
2473 cx.spawn({
2474 async move |cx| {
2475 let mut new_settings_content_rx = settings::watch_config_file(
2476 cx.background_executor(),
2477 fs,
2478 paths::settings_file().clone(),
2479 );
2480
2481 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2482 cx.update(|cx| {
2483 SettingsStore::update_global(cx, |settings, cx| {
2484 settings.set_user_settings(&new_settings_content, cx)
2485 })
2486 })
2487 .ok();
2488 }
2489 }
2490 })
2491 .detach();
2492}
2493
2494fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2495 completion
2496 .tools
2497 .iter()
2498 .map(|tool| tool.name.clone())
2499 .collect()
2500}
2501
/// Registers and starts a fake MCP context server named `name` that exposes
/// `tools`, returning a channel yielding each incoming tool call paired with
/// a sender the test uses to supply the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it;
    // the command is a placeholder since the fake transport never spawns it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport that answers the MCP handshake and tool listing, and
    // forwards each CallTool request to the test via the returned channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test sends a response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool listing complete before tests proceed.
    cx.run_until_parked();
    mcp_tool_calls_rx
}