1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a streamed Thinking event is rendered as a <think> block
    // preceding the assistant's visible reply in the thread's markdown.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking chunk, then the visible answer, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // The first request message must be a system prompt that reflects the
    // project context (e.g. the shell) and, because a tool is registered,
    // includes the tool-dependent sections.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // The shell set on the project context above is interpolated into the prompt.
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // This section is only present when at least one tool is registered.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    // Counterpart to test_system_prompt: when no tools are registered, the
    // tool-related sections must be omitted from the system prompt.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // Both sections below are tool-dependent and must not appear.
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies that the `cache` flag is set only on the final message of each
    // request: it moves forward as the conversation grows, and applies to the
    // trailing tool result as well.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; only compare the conversation tail.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request after running the tool must carry the entire
    // transcript, with the cache flag on the trailing tool-result message only.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (only runs with the "e2e" feature):
    // tool calls must complete and feed back into the conversation whether they
    // finish before or after the model stops streaming.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's output ("Ding") should have been echoed back by the
    // model in its final message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
393
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): while tool input is
    // still streaming, the thread's history should expose a partial,
    // incomplete tool use that becomes complete once streaming finishes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a request that triggers the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the later keys ("g") may not have
                        // streamed in yet; that's the partial state we want.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
445
446#[gpui::test]
447async fn test_tool_authorization(cx: &mut TestAppContext) {
448 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 let mut events = thread
452 .update(cx, |thread, cx| {
453 thread.add_tool(ToolRequiringPermission);
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
459 LanguageModelToolUse {
460 id: "tool_id_1".into(),
461 name: ToolRequiringPermission::name().into(),
462 raw_input: "{}".into(),
463 input: json!({}),
464 is_input_complete: true,
465 thought_signature: None,
466 },
467 ));
468 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
469 LanguageModelToolUse {
470 id: "tool_id_2".into(),
471 name: ToolRequiringPermission::name().into(),
472 raw_input: "{}".into(),
473 input: json!({}),
474 is_input_complete: true,
475 thought_signature: None,
476 },
477 ));
478 fake_model.end_last_completion_stream();
479 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
480 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
481
482 // Approve the first
483 tool_call_auth_1
484 .response
485 .send(tool_call_auth_1.options[1].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488
489 // Reject the second
490 tool_call_auth_2
491 .response
492 .send(tool_call_auth_1.options[2].id.clone())
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![
501 language_model::MessageContent::ToolResult(LanguageModelToolResult {
502 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }),
508 language_model::MessageContent::ToolResult(LanguageModelToolResult {
509 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
510 tool_name: ToolRequiringPermission::name().into(),
511 is_error: true,
512 content: "Permission to run tool denied by user".into(),
513 output: Some("Permission to run tool denied by user".into())
514 })
515 ]
516 );
517
518 // Simulate yet another tool call.
519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
520 LanguageModelToolUse {
521 id: "tool_id_3".into(),
522 name: ToolRequiringPermission::name().into(),
523 raw_input: "{}".into(),
524 input: json!({}),
525 is_input_complete: true,
526 thought_signature: None,
527 },
528 ));
529 fake_model.end_last_completion_stream();
530
531 // Respond by always allowing tools.
532 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
533 tool_call_auth_3
534 .response
535 .send(tool_call_auth_3.options[0].id.clone())
536 .unwrap();
537 cx.run_until_parked();
538 let completion = fake_model.pending_completions().pop().unwrap();
539 let message = completion.messages.last().unwrap();
540 assert_eq!(
541 message.content,
542 vec![language_model::MessageContent::ToolResult(
543 LanguageModelToolResult {
544 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
545 tool_name: ToolRequiringPermission::name().into(),
546 is_error: false,
547 content: "Allowed".into(),
548 output: Some("Allowed".into())
549 }
550 )]
551 );
552
553 // Simulate a final tool call, ensuring we don't trigger authorization.
554 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
555 LanguageModelToolUse {
556 id: "tool_id_4".into(),
557 name: ToolRequiringPermission::name().into(),
558 raw_input: "{}".into(),
559 input: json!({}),
560 is_input_complete: true,
561 thought_signature: None,
562 },
563 ));
564 fake_model.end_last_completion_stream();
565 cx.run_until_parked();
566 let completion = fake_model.pending_completions().pop().unwrap();
567 let message = completion.messages.last().unwrap();
568 assert_eq!(
569 message.content,
570 vec![language_model::MessageContent::ToolResult(
571 LanguageModelToolResult {
572 tool_use_id: "tool_id_4".into(),
573 tool_name: ToolRequiringPermission::name().into(),
574 is_error: false,
575 content: "Allowed".into(),
576 output: Some("Allowed".into())
577 }
578 )]
579 );
580}
581
582#[gpui::test]
583async fn test_tool_hallucination(cx: &mut TestAppContext) {
584 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
585 let fake_model = model.as_fake();
586
587 let mut events = thread
588 .update(cx, |thread, cx| {
589 thread.send(UserMessageId::new(), ["abc"], cx)
590 })
591 .unwrap();
592 cx.run_until_parked();
593 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
594 LanguageModelToolUse {
595 id: "tool_id_1".into(),
596 name: "nonexistent_tool".into(),
597 raw_input: "{}".into(),
598 input: json!({}),
599 is_input_complete: true,
600 thought_signature: None,
601 },
602 ));
603 fake_model.end_last_completion_stream();
604
605 let tool_call = expect_tool_call(&mut events).await;
606 assert_eq!(tool_call.title, "nonexistent_tool");
607 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
608 let update = expect_tool_call_update_fields(&mut events).await;
609 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
610}
611
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, `resume` should continue the same
    // conversation by appending a synthetic "Continue where you left off"
    // user message, with the cache flag moved onto it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // messages[0] is the system prompt; only compare the conversation tail.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream must terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends the synthetic continuation message and re-sends.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
721
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (rather than resuming) must still include the prior tool use and its
    // result in the next request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The limit arrives in the same turn as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new user message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
796
797async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
798 let event = events
799 .next()
800 .await
801 .expect("no tool call authorization event received")
802 .unwrap();
803 match event {
804 ThreadEvent::ToolCall(tool_call) => tool_call,
805 event => {
806 panic!("Unexpected event {event:?}");
807 }
808 }
809}
810
811async fn expect_tool_call_update_fields(
812 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
813) -> acp::ToolCallUpdate {
814 let event = events
815 .next()
816 .await
817 .expect("no tool call authorization event received")
818 .unwrap();
819 match event {
820 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
821 event => {
822 panic!("Unexpected event {event:?}");
823 }
824 }
825}
826
827async fn next_tool_call_authorization(
828 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
829) -> ToolCallAuthorization {
830 loop {
831 let event = events
832 .next()
833 .await
834 .expect("no tool call authorization event received")
835 .unwrap();
836 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
837 let permission_kinds = tool_call_authorization
838 .options
839 .iter()
840 .map(|o| o.kind)
841 .collect::<Vec<_>>();
842 assert_eq!(
843 permission_kinds,
844 vec![
845 acp::PermissionOptionKind::AllowAlways,
846 acp::PermissionOptionKind::AllowOnce,
847 acp::PermissionOptionKind::RejectOnce,
848 ]
849 );
850 return tool_call_authorization;
851 }
852 }
853}
854
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): two delay-tool calls
    // issued in the same turn must both complete and their outputs be
    // described in the model's final message.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // The final assistant message should mention the tools' "Ding" output.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
899
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which registered tools
    // are offered to the model, and that switching profiles takes effect on
    // the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
979
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that tools served by an MCP context server are exposed to the
    // model, that their calls round-trip through the server, and that a name
    // collision with a native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a context server exposing a tool named "echo", which will later
    // collide with the native EchoTool.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call must be forwarded to the context server with its arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts the completion captured before the tool
    // result was delivered; presumably it was meant to inspect the *new*
    // pending completion — confirm against the intended behavior.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    // The MCP tool is now exposed under a server-prefixed name.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the prefixed call reaches the server; the server still sees the
    // tool under its original, unprefixed name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1145
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Exercises how MCP (context-server) tool names are reconciled with native
    // tool names: conflicting names receive a server-id prefix, and names at
    // or over MAX_TOOL_NAME_LENGTH are truncated. The exact expectations are
    // pinned by the assertion at the bottom of this test.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names, duplicated on the "zzz" server below, so that
            // disambiguating them requires both prefixing and truncation.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Unique but one character over the limit — should simply be
            // truncated, with no prefix.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Expected naming (pinned below): unique names pass through unchanged;
    // the conflicting `echo` tools are prefixed with their full server id
    // (xxx_echo / yyy_echo); the duplicated near-limit `a…` names get a
    // one-letter server prefix and fit within the limit; the overlong `c…`
    // name is truncated to MAX_TOOL_NAME_LENGTH.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1311
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model): cancelling mid-turn must close
    // the event stream with a Cancelled stop reason — even while one of the
    // tools (InfiniteTool) is still running — and the thread must remain
    // usable for a subsequent send.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool's Completed status update so we know at
            // least one tool finished before we cancel.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1397
1398#[gpui::test]
1399async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1400 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1401 let fake_model = model.as_fake();
1402
1403 let events_1 = thread
1404 .update(cx, |thread, cx| {
1405 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1406 })
1407 .unwrap();
1408 cx.run_until_parked();
1409 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1410 cx.run_until_parked();
1411
1412 let events_2 = thread
1413 .update(cx, |thread, cx| {
1414 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1415 })
1416 .unwrap();
1417 cx.run_until_parked();
1418 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1419 fake_model
1420 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1421 fake_model.end_last_completion_stream();
1422
1423 let events_1 = events_1.collect::<Vec<_>>().await;
1424 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1425 let events_2 = events_2.collect::<Vec<_>>().await;
1426 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1427}
1428
1429#[gpui::test]
1430async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1431 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1432 let fake_model = model.as_fake();
1433
1434 let events_1 = thread
1435 .update(cx, |thread, cx| {
1436 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1437 })
1438 .unwrap();
1439 cx.run_until_parked();
1440 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1441 fake_model
1442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1443 fake_model.end_last_completion_stream();
1444 let events_1 = events_1.collect::<Vec<_>>().await;
1445
1446 let events_2 = thread
1447 .update(cx, |thread, cx| {
1448 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1449 })
1450 .unwrap();
1451 cx.run_until_parked();
1452 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1453 fake_model
1454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1455 fake_model.end_last_completion_stream();
1456 let events_2 = events_2.collect::<Vec<_>>().await;
1457
1458 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1459 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1460}
1461
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model refuses (StopReason::Refusal), the turn must stop with a
    // Refusal event and the thread must discard everything back through the
    // last user message — which, here, empties the whole transcript.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Partial assistant output appears in the transcript as it streams in.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1510
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first user message should clear the transcript
    // and reset token usage, while leaving the thread usable for new sends.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens tracks input + output from the last UsageUpdate;
        // max_tokens reflects the fake model's context window.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message wipes both transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation exchange.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1626
1627#[gpui::test]
1628async fn test_truncate_second_message(cx: &mut TestAppContext) {
1629 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1630 let fake_model = model.as_fake();
1631
1632 thread
1633 .update(cx, |thread, cx| {
1634 thread.send(UserMessageId::new(), ["Message 1"], cx)
1635 })
1636 .unwrap();
1637 cx.run_until_parked();
1638 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
1639 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1640 language_model::TokenUsage {
1641 input_tokens: 32_000,
1642 output_tokens: 16_000,
1643 cache_creation_input_tokens: 0,
1644 cache_read_input_tokens: 0,
1645 },
1646 ));
1647 fake_model.end_last_completion_stream();
1648 cx.run_until_parked();
1649
1650 let assert_first_message_state = |cx: &mut TestAppContext| {
1651 thread.clone().read_with(cx, |thread, _| {
1652 assert_eq!(
1653 thread.to_markdown(),
1654 indoc! {"
1655 ## User
1656
1657 Message 1
1658
1659 ## Assistant
1660
1661 Message 1 response
1662 "}
1663 );
1664
1665 assert_eq!(
1666 thread.latest_token_usage(),
1667 Some(acp_thread::TokenUsage {
1668 used_tokens: 32_000 + 16_000,
1669 max_tokens: 1_000_000,
1670 })
1671 );
1672 });
1673 };
1674
1675 assert_first_message_state(cx);
1676
1677 let second_message_id = UserMessageId::new();
1678 thread
1679 .update(cx, |thread, cx| {
1680 thread.send(second_message_id.clone(), ["Message 2"], cx)
1681 })
1682 .unwrap();
1683 cx.run_until_parked();
1684
1685 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
1686 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1687 language_model::TokenUsage {
1688 input_tokens: 40_000,
1689 output_tokens: 20_000,
1690 cache_creation_input_tokens: 0,
1691 cache_read_input_tokens: 0,
1692 },
1693 ));
1694 fake_model.end_last_completion_stream();
1695 cx.run_until_parked();
1696
1697 thread.read_with(cx, |thread, _| {
1698 assert_eq!(
1699 thread.to_markdown(),
1700 indoc! {"
1701 ## User
1702
1703 Message 1
1704
1705 ## Assistant
1706
1707 Message 1 response
1708
1709 ## User
1710
1711 Message 2
1712
1713 ## Assistant
1714
1715 Message 2 response
1716 "}
1717 );
1718
1719 assert_eq!(
1720 thread.latest_token_usage(),
1721 Some(acp_thread::TokenUsage {
1722 used_tokens: 40_000 + 20_000,
1723 max_tokens: 1_000_000,
1724 })
1725 );
1726 });
1727
1728 thread
1729 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
1730 .unwrap();
1731 cx.run_until_parked();
1732
1733 assert_first_message_state(cx);
1734}
1735
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model generates the thread title after the
    // first turn only; later turns must not re-trigger title generation.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Dedicated fake model for summarization, distinct from the chat model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Chunk boundaries are deliberately misaligned with words, and everything
    // after the first newline must be discarded from the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No pending summary completion means no second title request was made.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1781
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When building a completion request mid-turn, tool calls that are still
    // pending (here: one awaiting user permission) must be omitted, while
    // already-resolved tool calls are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This call blocks on a permission prompt, so it stays pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // Compare from index 1 on; messages[0] (the thread preamble) is not under
    // test here. Note the permission tool use is absent from the assistant
    // message and has no tool-result entry.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1858
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Exercises the ACP connection surface of the native agent end to end:
    // thread creation, model listing/selection, prompting, cancellation, and
    // session cleanup when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and verify the streamed reply lands in
    // the thread's markdown transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with a session-not-found error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1993
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the exact sequence of ACP tool-call events for a streamed tool
    // use: initial ToolCall (Pending), then updates for the completed input,
    // InProgress status, streamed content, and finally Completed with output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The partial tool use surfaces as a Pending tool call with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. The completed input arrives as a fields update carrying raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. The tool finishes with its raw output attached.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2102
2103#[gpui::test]
2104async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2105 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2106 let fake_model = model.as_fake();
2107
2108 let mut events = thread
2109 .update(cx, |thread, cx| {
2110 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2111 thread.send(UserMessageId::new(), ["Hello!"], cx)
2112 })
2113 .unwrap();
2114 cx.run_until_parked();
2115
2116 fake_model.send_last_completion_stream_text_chunk("Hey!");
2117 fake_model.end_last_completion_stream();
2118
2119 let mut retry_events = Vec::new();
2120 while let Some(Ok(event)) = events.next().await {
2121 match event {
2122 ThreadEvent::Retry(retry_status) => {
2123 retry_events.push(retry_status);
2124 }
2125 ThreadEvent::Stop(..) => break,
2126 _ => {}
2127 }
2128 }
2129
2130 assert_eq!(retry_events.len(), 0);
2131 thread.read_with(cx, |thread, _cx| {
2132 assert_eq!(
2133 thread.to_markdown(),
2134 indoc! {"
2135 ## User
2136
2137 Hello!
2138
2139 ## Assistant
2140
2141 Hey!
2142 "}
2143 )
2144 });
2145}
2146
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a retryable provider error (server overloaded) mid-stream
    // should schedule an automatic retry after `retry_after`, emit exactly one
    // Retry event, and resume the turn in a new assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream partial text, then fail with a retryable overloaded error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the retry delay so the follow-up
    // completion request is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Finish the retried completion successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the interrupted text, a `[resume]` marker, and the
    // continuation under a fresh Assistant heading.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2211
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // Verify that when the stream fails after a tool call was issued, the
    // pending tool call still completes and its result is included in the
    // retried request's message history.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode is required for automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model requests a tool call, then the stream fails with a
    // retryable error before the turn can finish.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must contain the user message,
    // the assistant's tool use, and the finished tool result (index 0 is
    // skipped; only the conversation-suffix messages are asserted).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2288
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // Verify that after MAX_RETRY_ATTEMPTS consecutive retryable failures,
    // the thread stops retrying and surfaces the final error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is required for automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry: MAX_RETRY_ATTEMPTS + 1
    // stream errors in total, advancing past the retry delay each time.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, separating retry notifications from errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per allowed attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single ServerOverloaded error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2343
2344/// Filters out the stop events for asserting against in tests
2345fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2346 result_events
2347 .into_iter()
2348 .filter_map(|event| match event.unwrap() {
2349 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2350 _ => None,
2351 })
2352 .collect()
2353}
2354
/// Bundle of entities produced by `setup` that a thread test interacts with.
struct ThreadTest {
    // The language model the thread sends completions to (fake or real).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store through which fake MCP context servers are registered.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing both the project tree and the settings file.
    fs: Arc<FakeFs>,
}
2362
/// Which language model backs a test thread.
enum TestModel {
    // A real Claude Sonnet 4 model resolved from the registry (needs
    // provider authentication; see `setup`).
    Sonnet4,
    // An in-process `FakeLanguageModel` driven by the test itself.
    Fake,
}
2367
2368impl TestModel {
2369 fn id(&self) -> LanguageModelId {
2370 match self {
2371 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2372 TestModel::Fake => unreachable!(),
2373 }
2374 }
2375}
2376
/// Builds a fully wired `Thread` plus its supporting entities, backed by
/// either a `FakeLanguageModel` or a real Sonnet 4 connection, with a test
/// settings profile that enables every test tool.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that makes "test-profile" the default and
    // enables all of the test tools in it.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real-model path needs HTTP, client, and provider setup;
        // the fake model is constructed directly below.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the fake settings file
        // so tests can change settings by writing to `fs`.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    // An empty project rooted at /test.
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake one is built inline; a real one is looked
    // up in the registry and its provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // Assemble the thread and its collaborators.
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2478
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Initialize env_logger only when the developer opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2486
2487fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2488 let fs = fs.clone();
2489 cx.spawn({
2490 async move |cx| {
2491 let mut new_settings_content_rx = settings::watch_config_file(
2492 cx.background_executor(),
2493 fs,
2494 paths::settings_file().clone(),
2495 );
2496
2497 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2498 cx.update(|cx| {
2499 SettingsStore::update_global(cx, |settings, cx| {
2500 settings.set_user_settings(&new_settings_content, cx)
2501 })
2502 })
2503 .ok();
2504 }
2505 }
2506 })
2507 .detach();
2508}
2509
2510fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2511 completion
2512 .tools
2513 .iter()
2514 .map(|tool| tool.name.clone())
2515 .collect()
2516}
2517
/// Registers a fake MCP context server named `name` that advertises `tools`,
/// and returns a channel yielding each incoming tool call together with a
/// oneshot sender the test uses to deliver that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will start it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // The binary is never executed; the fake transport below
                    // intercepts all requests.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport that answers the MCP handshake, reports `tools`, and
    // forwards each CallTool request to the test via the channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its startup handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}