1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // The system prompt should be the first message of every completion
    // request, and should incorporate both the project context (here, the
    // shell name) and tool-dependent sections when a tool is registered.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so its value is detectable in the prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // Register a tool so tool-related prompt sections are enabled.
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    // The system prompt must come first in the request.
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // Project context (the shell override) made it into the prompt.
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // Tool-dependent section is present because a tool is registered.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy across a conversation: in every
    // request, only the final message carries `cache: true`; all earlier
    // messages are sent uncached. This holds for plain user messages and for
    // tool-result messages alike.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool ran (EchoTool echoes its input); the follow-up request should
    // end with the tool result as the only cached message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
332
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (Sonnet 4): exercises tool calls
    // that finish both before and after the model's streaming stops.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            // Swap tools so the model can only use the delay tool.
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should have been echoed into the
        // final assistant message.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
392
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test: while a tool call's input is still streaming, the
    // thread should expose a partial (incomplete) tool use, which becomes
    // complete once streaming finishes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may be incomplete; use the
                        // absence of the last key ("g") as the partial marker.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
444
445#[gpui::test]
446async fn test_tool_authorization(cx: &mut TestAppContext) {
447 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
448 let fake_model = model.as_fake();
449
450 let mut events = thread
451 .update(cx, |thread, cx| {
452 thread.add_tool(ToolRequiringPermission);
453 thread.send(UserMessageId::new(), ["abc"], cx)
454 })
455 .unwrap();
456 cx.run_until_parked();
457 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
458 LanguageModelToolUse {
459 id: "tool_id_1".into(),
460 name: ToolRequiringPermission::name().into(),
461 raw_input: "{}".into(),
462 input: json!({}),
463 is_input_complete: true,
464 },
465 ));
466 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
467 LanguageModelToolUse {
468 id: "tool_id_2".into(),
469 name: ToolRequiringPermission::name().into(),
470 raw_input: "{}".into(),
471 input: json!({}),
472 is_input_complete: true,
473 },
474 ));
475 fake_model.end_last_completion_stream();
476 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
477 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
478
479 // Approve the first
480 tool_call_auth_1
481 .response
482 .send(tool_call_auth_1.options[1].id.clone())
483 .unwrap();
484 cx.run_until_parked();
485
486 // Reject the second
487 tool_call_auth_2
488 .response
489 .send(tool_call_auth_1.options[2].id.clone())
490 .unwrap();
491 cx.run_until_parked();
492
493 let completion = fake_model.pending_completions().pop().unwrap();
494 let message = completion.messages.last().unwrap();
495 assert_eq!(
496 message.content,
497 vec![
498 language_model::MessageContent::ToolResult(LanguageModelToolResult {
499 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
500 tool_name: ToolRequiringPermission::name().into(),
501 is_error: false,
502 content: "Allowed".into(),
503 output: Some("Allowed".into())
504 }),
505 language_model::MessageContent::ToolResult(LanguageModelToolResult {
506 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
507 tool_name: ToolRequiringPermission::name().into(),
508 is_error: true,
509 content: "Permission to run tool denied by user".into(),
510 output: Some("Permission to run tool denied by user".into())
511 })
512 ]
513 );
514
515 // Simulate yet another tool call.
516 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
517 LanguageModelToolUse {
518 id: "tool_id_3".into(),
519 name: ToolRequiringPermission::name().into(),
520 raw_input: "{}".into(),
521 input: json!({}),
522 is_input_complete: true,
523 },
524 ));
525 fake_model.end_last_completion_stream();
526
527 // Respond by always allowing tools.
528 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
529 tool_call_auth_3
530 .response
531 .send(tool_call_auth_3.options[0].id.clone())
532 .unwrap();
533 cx.run_until_parked();
534 let completion = fake_model.pending_completions().pop().unwrap();
535 let message = completion.messages.last().unwrap();
536 assert_eq!(
537 message.content,
538 vec![language_model::MessageContent::ToolResult(
539 LanguageModelToolResult {
540 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
541 tool_name: ToolRequiringPermission::name().into(),
542 is_error: false,
543 content: "Allowed".into(),
544 output: Some("Allowed".into())
545 }
546 )]
547 );
548
549 // Simulate a final tool call, ensuring we don't trigger authorization.
550 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
551 LanguageModelToolUse {
552 id: "tool_id_4".into(),
553 name: ToolRequiringPermission::name().into(),
554 raw_input: "{}".into(),
555 input: json!({}),
556 is_input_complete: true,
557 },
558 ));
559 fake_model.end_last_completion_stream();
560 cx.run_until_parked();
561 let completion = fake_model.pending_completions().pop().unwrap();
562 let message = completion.messages.last().unwrap();
563 assert_eq!(
564 message.content,
565 vec![language_model::MessageContent::ToolResult(
566 LanguageModelToolResult {
567 tool_use_id: "tool_id_4".into(),
568 tool_name: ToolRequiringPermission::name().into(),
569 is_error: false,
570 content: "Allowed".into(),
571 output: Some("Allowed".into())
572 }
573 )]
574 );
575}
576
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    // When the model calls a tool that was never registered, the thread
    // should still surface a ToolCall event and then mark it failed, rather
    // than panicking or dropping the call silently.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Note: no tool named "nonexistent_tool" has been added to the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // The call is first reported as pending, then updated to Failed.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
605
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the server reports ToolUseLimitReached, the turn ends with an
    // error; calling `resume` should continue the conversation by appending a
    // synthetic "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The echo tool ran; its result should be the last (and cached) message
    // of the follow-up request. messages[0] is the system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a continuation prompt as a new (cached) user message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    // The resumed turn completes normally.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
714
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (rather than resuming) should still work: the interrupted tool call
    // and its result remain in history, followed by the new message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit the tool use and the limit status in the same stream, then close.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Prior history is preserved uncached; only the new message is cached.
    // messages[0] is the system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
788
789async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
790 let event = events
791 .next()
792 .await
793 .expect("no tool call authorization event received")
794 .unwrap();
795 match event {
796 ThreadEvent::ToolCall(tool_call) => tool_call,
797 event => {
798 panic!("Unexpected event {event:?}");
799 }
800 }
801}
802
803async fn expect_tool_call_update_fields(
804 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
805) -> acp::ToolCallUpdate {
806 let event = events
807 .next()
808 .await
809 .expect("no tool call authorization event received")
810 .unwrap();
811 match event {
812 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
813 event => {
814 panic!("Unexpected event {event:?}");
815 }
816 }
817}
818
819async fn next_tool_call_authorization(
820 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
821) -> ToolCallAuthorization {
822 loop {
823 let event = events
824 .next()
825 .await
826 .expect("no tool call authorization event received")
827 .unwrap();
828 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
829 let permission_kinds = tool_call_authorization
830 .options
831 .iter()
832 .map(|o| o.kind)
833 .collect::<Vec<_>>();
834 assert_eq!(
835 permission_kinds,
836 vec![
837 acp::PermissionOptionKind::AllowAlways,
838 acp::PermissionOptionKind::AllowOnce,
839 acp::PermissionOptionKind::RejectOnce,
840 ]
841 );
842 return tool_call_authorization;
843 }
844 }
845}
846
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test: two delay-tool calls issued in one turn should both
    // complete, and the model should describe their outputs afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final assistant message and
        // check it mentions the delay tool's output ("Ding").
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
891
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Switching agent profiles should change which registered tools are
    // offered to the model in completion requests.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles below selectively enable them.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
971
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // MCP (context-server) tools should be offered to the model, routed to
    // the server when called, and — when an MCP tool name collides with a
    // native tool — disambiguated by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Start a fake MCP server exposing a single "echo" tool that reuses the
    // native EchoTool's input schema.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call should have been forwarded to the MCP server verbatim.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` captured before the
    // tool result was delivered; presumably a fresh
    // `fake_model.pending_completions().pop()` was intended here — verify.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    // The MCP tool is renamed "test_server_echo"; the native one keeps "echo".
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the MCP-prefixed call reaches the server; it is sent with the
    // server's original (unprefixed) tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1134
// Verifies MCP (context-server) tool name reconciliation: collisions between
// native and server tools, collisions across servers, and how names are
// prefixed/truncated so the final name fits within MAX_TOOL_NAME_LENGTH.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names that also exist on server "zzz" below: a full
            // server-name prefix would overflow MAX_TOOL_NAME_LENGTH.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Already longer than the limit on its own.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected names (sorted): colliding "echo" tools get a full server-name
    // prefix; colliding long names appear to get a single-character prefix,
    // and the over-long "c…" name is truncated — NOTE(review): exact
    // prefix/truncation policy inferred from these expectations.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1300
// Live (e2e-only) test: cancelling mid-turn while a tool is still running
// must close the event stream with `Cancelled`, and the thread must remain
// usable for subsequent sends.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order requested by the prompt.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo tool specifically; the infinite
            // tool is expected to keep running until cancelled.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1386
1387#[gpui::test]
1388async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1389 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1390 let fake_model = model.as_fake();
1391
1392 let events_1 = thread
1393 .update(cx, |thread, cx| {
1394 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1395 })
1396 .unwrap();
1397 cx.run_until_parked();
1398 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1399 cx.run_until_parked();
1400
1401 let events_2 = thread
1402 .update(cx, |thread, cx| {
1403 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1404 })
1405 .unwrap();
1406 cx.run_until_parked();
1407 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1408 fake_model
1409 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1410 fake_model.end_last_completion_stream();
1411
1412 let events_1 = events_1.collect::<Vec<_>>().await;
1413 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1414 let events_2 = events_2.collect::<Vec<_>>().await;
1415 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1416}
1417
1418#[gpui::test]
1419async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1420 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1421 let fake_model = model.as_fake();
1422
1423 let events_1 = thread
1424 .update(cx, |thread, cx| {
1425 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1426 })
1427 .unwrap();
1428 cx.run_until_parked();
1429 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1430 fake_model
1431 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1432 fake_model.end_last_completion_stream();
1433 let events_1 = events_1.collect::<Vec<_>>().await;
1434
1435 let events_2 = thread
1436 .update(cx, |thread, cx| {
1437 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1438 })
1439 .unwrap();
1440 cx.run_until_parked();
1441 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1442 fake_model
1443 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1444 fake_model.end_last_completion_stream();
1445 let events_2 = events_2.collect::<Vec<_>>().await;
1446
1447 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1448 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_refusal(cx: &mut TestAppContext) {
1453 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454 let fake_model = model.as_fake();
1455
1456 let events = thread
1457 .update(cx, |thread, cx| {
1458 thread.send(UserMessageId::new(), ["Hello"], cx)
1459 })
1460 .unwrap();
1461 cx.run_until_parked();
1462 thread.read_with(cx, |thread, _| {
1463 assert_eq!(
1464 thread.to_markdown(),
1465 indoc! {"
1466 ## User
1467
1468 Hello
1469 "}
1470 );
1471 });
1472
1473 fake_model.send_last_completion_stream_text_chunk("Hey!");
1474 cx.run_until_parked();
1475 thread.read_with(cx, |thread, _| {
1476 assert_eq!(
1477 thread.to_markdown(),
1478 indoc! {"
1479 ## User
1480
1481 Hello
1482
1483 ## Assistant
1484
1485 Hey!
1486 "}
1487 );
1488 });
1489
1490 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1491 fake_model
1492 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1493 let events = events.collect::<Vec<_>>().await;
1494 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1495 thread.read_with(cx, |thread, _| {
1496 assert_eq!(thread.to_markdown(), "");
1497 });
1498}
1499
// Truncating at the first (and only) user message should reset the thread to
// empty, clear the reported token usage, and leave the thread usable for a
// fresh send afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at exactly this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model streams a usage update.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message erases the whole conversation and the
    // usage snapshot along with it.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1615
// Truncating at the second user message should restore the thread — markdown
// and token usage — to exactly its state after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before the second send and after the
    // truncation, so it proves the state was fully restored.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Keep the id of the second message so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    // Everything from the second message onward is gone again.
    assert_first_message_state(cx);
}
1724
1725#[gpui::test]
1726async fn test_title_generation(cx: &mut TestAppContext) {
1727 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1728 let fake_model = model.as_fake();
1729
1730 let summary_model = Arc::new(FakeLanguageModel::default());
1731 thread.update(cx, |thread, cx| {
1732 thread.set_summarization_model(Some(summary_model.clone()), cx)
1733 });
1734
1735 let send = thread
1736 .update(cx, |thread, cx| {
1737 thread.send(UserMessageId::new(), ["Hello"], cx)
1738 })
1739 .unwrap();
1740 cx.run_until_parked();
1741
1742 fake_model.send_last_completion_stream_text_chunk("Hey!");
1743 fake_model.end_last_completion_stream();
1744 cx.run_until_parked();
1745 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1746
1747 // Ensure the summary model has been invoked to generate a title.
1748 summary_model.send_last_completion_stream_text_chunk("Hello ");
1749 summary_model.send_last_completion_stream_text_chunk("world\nG");
1750 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1751 summary_model.end_last_completion_stream();
1752 send.collect::<Vec<_>>().await;
1753 cx.run_until_parked();
1754 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1755
1756 // Send another message, ensuring no title is generated this time.
1757 let send = thread
1758 .update(cx, |thread, cx| {
1759 thread.send(UserMessageId::new(), ["Hello again"], cx)
1760 })
1761 .unwrap();
1762 cx.run_until_parked();
1763 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1764 fake_model.end_last_completion_stream();
1765 cx.run_until_parked();
1766 assert_eq!(summary_model.pending_completions(), Vec::new());
1767 send.collect::<Vec<_>>().await;
1768 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1769}
1770
// Tool calls that are still pending (e.g. awaiting user permission) must be
// omitted when building the next completion request, while completed tool
// calls are included together with their results.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // One tool call that will block on a permission prompt...
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // ...and one that completes without interaction.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // Compared from index 1; messages[0] is presumably the system/context
    // message — TODO confirm against build_completion_request.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1845
// End-to-end exercise of the ACP `AgentConnection` surface: thread creation,
// model listing/selection, prompting, cancellation, and session teardown when
// the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt through the ACP thread and stream back a reply.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1983
// Verifies the full sequence of ACP tool-call updates emitted for a single
// tool invocation: initial Pending call, raw-input update once streaming
// input completes, InProgress, streamed content, and final Completed with the
// tool's output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The tool call is first surfaced as Pending with the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. Completed input arrives as a fields update carrying the raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. The call transitions to InProgress once the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. Finally the call completes with the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2090
2091#[gpui::test]
2092async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2093 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2094 let fake_model = model.as_fake();
2095
2096 let mut events = thread
2097 .update(cx, |thread, cx| {
2098 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2099 thread.send(UserMessageId::new(), ["Hello!"], cx)
2100 })
2101 .unwrap();
2102 cx.run_until_parked();
2103
2104 fake_model.send_last_completion_stream_text_chunk("Hey!");
2105 fake_model.end_last_completion_stream();
2106
2107 let mut retry_events = Vec::new();
2108 while let Some(Ok(event)) = events.next().await {
2109 match event {
2110 ThreadEvent::Retry(retry_status) => {
2111 retry_events.push(retry_status);
2112 }
2113 ThreadEvent::Stop(..) => break,
2114 _ => {}
2115 }
2116 }
2117
2118 assert_eq!(retry_events.len(), 0);
2119 thread.read_with(cx, |thread, _cx| {
2120 assert_eq!(
2121 thread.to_markdown(),
2122 indoc! {"
2123 ## User
2124
2125 Hello!
2126
2127 ## Assistant
2128
2129 Hey!
2130 "}
2131 )
2132 });
2133}
2134
2135#[gpui::test]
2136async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2137 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2138 let fake_model = model.as_fake();
2139
2140 let mut events = thread
2141 .update(cx, |thread, cx| {
2142 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2143 thread.send(UserMessageId::new(), ["Hello!"], cx)
2144 })
2145 .unwrap();
2146 cx.run_until_parked();
2147
2148 fake_model.send_last_completion_stream_text_chunk("Hey,");
2149 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2150 provider: LanguageModelProviderName::new("Anthropic"),
2151 retry_after: Some(Duration::from_secs(3)),
2152 });
2153 fake_model.end_last_completion_stream();
2154
2155 cx.executor().advance_clock(Duration::from_secs(3));
2156 cx.run_until_parked();
2157
2158 fake_model.send_last_completion_stream_text_chunk("there!");
2159 fake_model.end_last_completion_stream();
2160 cx.run_until_parked();
2161
2162 let mut retry_events = Vec::new();
2163 while let Some(Ok(event)) = events.next().await {
2164 match event {
2165 ThreadEvent::Retry(retry_status) => {
2166 retry_events.push(retry_status);
2167 }
2168 ThreadEvent::Stop(..) => break,
2169 _ => {}
2170 }
2171 }
2172
2173 assert_eq!(retry_events.len(), 1);
2174 assert!(matches!(
2175 retry_events[0],
2176 acp_thread::RetryStatus { attempt: 1, .. }
2177 ));
2178 thread.read_with(cx, |thread, _cx| {
2179 assert_eq!(
2180 thread.to_markdown(),
2181 indoc! {"
2182 ## User
2183
2184 Hello!
2185
2186 ## Assistant
2187
2188 Hey,
2189
2190 [resume]
2191
2192 ## Assistant
2193
2194 there!
2195 "}
2196 )
2197 });
2198}
2199
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // When the completion stream errors out partway through a tool call, the
    // thread should still run the pending tool and carry its result into the
    // retried request rather than dropping the call.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries on provider errors.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model emits a complete tool use, then the provider reports an
    // overload error with a 3s retry hint before the stream ends.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the follow-up request must contain the original
    // user message, the assistant's tool use, and the tool result produced by
    // actually running EchoTool (which echoes its input text).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn; the final agent message should hold only the
    // text, with no tool results attached to it.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2275
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive provider failures, the thread must
    // surface the underlying error instead of retrying indefinitely.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries on provider errors.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every allowed retry, advancing the fake
    // clock past each 3s retry_after hint so the next attempt fires.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, collecting retry notifications and errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Expect one retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS,
    // followed by a single terminal ServerOverloaded error.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2330
2331/// Filters out the stop events for asserting against in tests
2332fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2333 result_events
2334 .into_iter()
2335 .filter_map(|event| match event.unwrap() {
2336 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2337 _ => None,
2338 })
2339 .collect()
2340}
2341
/// Bundle of entities returned by `setup` for driving a thread under test.
struct ThreadTest {
    // The active model; a FakeLanguageModel when built with TestModel::Fake.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Shared project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used by tests that register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the project and the watched settings file.
    fs: Arc<FakeFs>,
}
2349
/// Which language model a test runs against: a real registry model or the
/// in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2354
2355impl TestModel {
2356 fn id(&self) -> LanguageModelId {
2357 match self {
2358 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2359 TestModel::Fake => unreachable!(),
2360 }
2361 }
2362}
2363
/// Builds a [`ThreadTest`]: a fake filesystem with a settings file enabling
/// every test tool, an empty test project, the requested model (fake or
/// resolved from the registry and authenticated), and a fresh `Thread`.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that turns on all test tools under a dedicated
    // "test-profile" agent profile.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real-model runs need an HTTP client, a production client
                // connection, and the full language-model provider registry.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the fake settings file
        // so tests can edit settings at runtime by rewriting the file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    // Empty project rooted at /test.
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // The fake model is constructed directly; real models are looked up in
    // the registry by id and their provider authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2468
/// Initializes `env_logger` before tests run, but only when the caller has
/// opted in by exporting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
2476
2477fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2478 let fs = fs.clone();
2479 cx.spawn({
2480 async move |cx| {
2481 let mut new_settings_content_rx = settings::watch_config_file(
2482 cx.background_executor(),
2483 fs,
2484 paths::settings_file().clone(),
2485 );
2486
2487 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2488 cx.update(|cx| {
2489 SettingsStore::update_global(cx, |settings, cx| {
2490 settings.set_user_settings(&new_settings_content, cx)
2491 })
2492 })
2493 .ok();
2494 }
2495 }
2496 })
2497 .detach();
2498}
2499
2500fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2501 completion
2502 .tools
2503 .iter()
2504 .map(|tool| tool.name.clone())
2505 .collect()
2506}
2507
/// Registers and starts a fake MCP context server named `name` that exposes
/// `tools`, returning a channel of incoming tool calls. Each received item
/// pairs the call's params with a oneshot sender the test uses to reply.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport: answers Initialize and ListTools directly, and forwards
    // each CallTool request to the test through the returned channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its initialize/list-tools handshake.
    cx.run_until_parked();
    mcp_tool_calls_rx
}