1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::{ProjectContext, UserPromptId, UserRulesContext};
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39
40use util::path;
41
42mod test_tools;
43use test_tools::*;
44
/// Verifies the simplest round-trip: a user message is sent, the fake model
/// streams a plain-text reply, and the thread records it as an Assistant
/// message that ends the turn.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Let the send reach the fake model before driving its response stream.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    // Exactly one turn should have completed, with an EndTurn stop reason.
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
74
/// Verifies that a streamed `Thinking` event is rendered in the thread's
/// markdown as a `<think>` block preceding the regular text response.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    // Let the send reach the fake model before streaming its response.
    cx.run_until_parked();
    // Stream a thinking chunk first, then the visible answer.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
118
/// Verifies that the system prompt is the first request message and that it
/// reflects the project context (shell) and, because a tool is registered,
/// includes the diagnostics-fixing guidance section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Seed the project context with a recognizable shell name we can grep for
    // in the generated system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    // The system prompt must be the very first message in the request.
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // With tools registered, the prompt should include the diagnostics section.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
163
/// Verifies that when no tools are registered, the system prompt omits both
/// the tool-use and diagnostics-fixing sections.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // Neither tool-related section should be present without registered tools.
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
199
/// Verifies the prompt-caching policy: across successive requests, only the
/// final message of the request carries `cache: true`; all earlier history
/// (including prior cache points) is sent uncached. Covers plain user
/// messages and tool-result messages.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so assertions start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request's final message is the tool
    // result, which becomes the new (sole) cache point.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
333
/// End-to-end (e2e-gated) test against a real model: exercises a fast tool
/// call that finishes before streaming stops, and a delayed tool call that
/// finishes after, checking the model incorporates the tool output.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool outputs "Ding"; the model's final message should
        // echo it back somewhere in its text content.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
393
/// End-to-end (e2e-gated) test verifying that tool-call input streams
/// incrementally: while a call is still Pending we should observe at least
/// one partially-populated input, and once complete all fields are present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to invoke the word_list tool so we can watch its input stream in.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Missing "g" (a late field) while pending proves we saw
                        // a partially streamed input.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
445
446#[gpui::test]
447async fn test_tool_authorization(cx: &mut TestAppContext) {
448 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 let mut events = thread
452 .update(cx, |thread, cx| {
453 thread.add_tool(ToolRequiringPermission);
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
459 LanguageModelToolUse {
460 id: "tool_id_1".into(),
461 name: ToolRequiringPermission::name().into(),
462 raw_input: "{}".into(),
463 input: json!({}),
464 is_input_complete: true,
465 },
466 ));
467 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
468 LanguageModelToolUse {
469 id: "tool_id_2".into(),
470 name: ToolRequiringPermission::name().into(),
471 raw_input: "{}".into(),
472 input: json!({}),
473 is_input_complete: true,
474 },
475 ));
476 fake_model.end_last_completion_stream();
477 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
478 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
479
480 // Approve the first
481 tool_call_auth_1
482 .response
483 .send(tool_call_auth_1.options[1].id.clone())
484 .unwrap();
485 cx.run_until_parked();
486
487 // Reject the second
488 tool_call_auth_2
489 .response
490 .send(tool_call_auth_1.options[2].id.clone())
491 .unwrap();
492 cx.run_until_parked();
493
494 let completion = fake_model.pending_completions().pop().unwrap();
495 let message = completion.messages.last().unwrap();
496 assert_eq!(
497 message.content,
498 vec![
499 language_model::MessageContent::ToolResult(LanguageModelToolResult {
500 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
501 tool_name: ToolRequiringPermission::name().into(),
502 is_error: false,
503 content: "Allowed".into(),
504 output: Some("Allowed".into())
505 }),
506 language_model::MessageContent::ToolResult(LanguageModelToolResult {
507 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
508 tool_name: ToolRequiringPermission::name().into(),
509 is_error: true,
510 content: "Permission to run tool denied by user".into(),
511 output: Some("Permission to run tool denied by user".into())
512 })
513 ]
514 );
515
516 // Simulate yet another tool call.
517 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
518 LanguageModelToolUse {
519 id: "tool_id_3".into(),
520 name: ToolRequiringPermission::name().into(),
521 raw_input: "{}".into(),
522 input: json!({}),
523 is_input_complete: true,
524 },
525 ));
526 fake_model.end_last_completion_stream();
527
528 // Respond by always allowing tools.
529 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
530 tool_call_auth_3
531 .response
532 .send(tool_call_auth_3.options[0].id.clone())
533 .unwrap();
534 cx.run_until_parked();
535 let completion = fake_model.pending_completions().pop().unwrap();
536 let message = completion.messages.last().unwrap();
537 assert_eq!(
538 message.content,
539 vec![language_model::MessageContent::ToolResult(
540 LanguageModelToolResult {
541 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
542 tool_name: ToolRequiringPermission::name().into(),
543 is_error: false,
544 content: "Allowed".into(),
545 output: Some("Allowed".into())
546 }
547 )]
548 );
549
550 // Simulate a final tool call, ensuring we don't trigger authorization.
551 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
552 LanguageModelToolUse {
553 id: "tool_id_4".into(),
554 name: ToolRequiringPermission::name().into(),
555 raw_input: "{}".into(),
556 input: json!({}),
557 is_input_complete: true,
558 },
559 ));
560 fake_model.end_last_completion_stream();
561 cx.run_until_parked();
562 let completion = fake_model.pending_completions().pop().unwrap();
563 let message = completion.messages.last().unwrap();
564 assert_eq!(
565 message.content,
566 vec![language_model::MessageContent::ToolResult(
567 LanguageModelToolResult {
568 tool_use_id: "tool_id_4".into(),
569 tool_name: ToolRequiringPermission::name().into(),
570 is_error: false,
571 content: "Allowed".into(),
572 output: Some("Allowed".into())
573 }
574 )]
575 );
576}
577
/// Verifies that when the model invokes a tool that was never registered,
/// the thread surfaces a tool call that immediately transitions to Failed.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // No tool named "nonexistent_tool" was added to the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
606
/// Verifies that hitting the tool-use limit surfaces `ToolUseLimitReachedError`
/// and that `resume` continues the turn by appending a synthetic
/// "Continue where you left off" user message as the new cache point.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran; the follow-up request ends with its (cached) result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    // Resuming replays the history and appends a synthetic continuation
    // message, which becomes the new cache point.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
715
/// Verifies that after the tool-use limit is hit, sending a fresh user
/// message (rather than resuming) keeps the prior tool use/result in the
/// history and makes the new message the cache point.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Hit the limit in the same stream as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
789
790async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
791 let event = events
792 .next()
793 .await
794 .expect("no tool call authorization event received")
795 .unwrap();
796 match event {
797 ThreadEvent::ToolCall(tool_call) => tool_call,
798 event => {
799 panic!("Unexpected event {event:?}");
800 }
801 }
802}
803
804async fn expect_tool_call_update_fields(
805 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
806) -> acp::ToolCallUpdate {
807 let event = events
808 .next()
809 .await
810 .expect("no tool call authorization event received")
811 .unwrap();
812 match event {
813 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
814 event => {
815 panic!("Unexpected event {event:?}");
816 }
817 }
818}
819
820async fn next_tool_call_authorization(
821 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
822) -> ToolCallAuthorization {
823 loop {
824 let event = events
825 .next()
826 .await
827 .expect("no tool call authorization event received")
828 .unwrap();
829 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
830 let permission_kinds = tool_call_authorization
831 .options
832 .iter()
833 .map(|o| o.kind)
834 .collect::<Vec<_>>();
835 assert_eq!(
836 permission_kinds,
837 vec![
838 acp::PermissionOptionKind::AllowAlways,
839 acp::PermissionOptionKind::AllowOnce,
840 acp::PermissionOptionKind::RejectOnce,
841 ]
842 );
843 return tool_call_authorization;
844 }
845 }
846}
847
/// End-to-end (e2e-gated) test: two delay-tool invocations in a single turn
/// must both run, and the model's final message must describe their output.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate all text content; the delay tool outputs "Ding".
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
892
/// Verifies that agent profiles gate which tools are offered to the model:
/// switching profiles changes the tool list sent with the next completion.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
972
/// Verifies MCP (context-server) tool integration: an MCP tool is exposed to
/// the model, its invocation round-trips through the fake server, and a name
/// collision with a native tool is resolved by prefixing the MCP tool with
/// its server name ("test_server_echo").
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a fake MCP server exposing a single "echo" tool whose schema
    // mirrors the native EchoTool's input.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` popped before the tool
    // ran, so it's trivially true — presumably the intent was to pop the new
    // pending completion here; confirm before changing.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Call both the renamed MCP tool and the native tool in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    // The MCP server still sees its original, un-prefixed tool name.
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1135
/// Verifies how MCP (context-server) tool names are disambiguated and
/// truncated: unique names pass through unchanged, names colliding with a
/// native tool get a server-derived prefix, and names at or over
/// `MAX_TOOL_NAME_LENGTH` are shortened so the final name still fits.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Register the native tools under the new profile.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Names just under the limit: prefixing them would overflow,
            // forcing the truncation path.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            // Same long names as server "yyy" to force cross-server collisions,
            // plus one name that exceeds the limit outright.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected resolution (sorted): unique names kept; "echo" collisions get a
    // full server prefix ("xxx_echo"/"yyy_echo") while the native tool keeps
    // the bare name; long colliding names appear with a single-character
    // server discriminator and truncated to fit — NOTE(review): exact
    // truncation scheme assumed from these fixtures, confirm in the
    // MAX_TOOL_NAME_LENGTH handling code.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1301
/// E2E test (runs only with the `e2e` feature): cancelling mid-turn must
/// close the event stream with `Cancelled` even while a tool is still
/// running, and the thread must still accept a new message afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo tool specifically; the infinite
            // tool is expected to keep running until cancellation.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1387
1388#[gpui::test]
1389async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1390 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1391 let fake_model = model.as_fake();
1392
1393 let events_1 = thread
1394 .update(cx, |thread, cx| {
1395 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1396 })
1397 .unwrap();
1398 cx.run_until_parked();
1399 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1400 cx.run_until_parked();
1401
1402 let events_2 = thread
1403 .update(cx, |thread, cx| {
1404 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1405 })
1406 .unwrap();
1407 cx.run_until_parked();
1408 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1409 fake_model
1410 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1411 fake_model.end_last_completion_stream();
1412
1413 let events_1 = events_1.collect::<Vec<_>>().await;
1414 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1415 let events_2 = events_2.collect::<Vec<_>>().await;
1416 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1417}
1418
1419#[gpui::test]
1420async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1421 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1422 let fake_model = model.as_fake();
1423
1424 let events_1 = thread
1425 .update(cx, |thread, cx| {
1426 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1427 })
1428 .unwrap();
1429 cx.run_until_parked();
1430 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1431 fake_model
1432 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1433 fake_model.end_last_completion_stream();
1434 let events_1 = events_1.collect::<Vec<_>>().await;
1435
1436 let events_2 = thread
1437 .update(cx, |thread, cx| {
1438 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1439 })
1440 .unwrap();
1441 cx.run_until_parked();
1442 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1443 fake_model
1444 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1445 fake_model.end_last_completion_stream();
1446 let events_2 = events_2.collect::<Vec<_>>().await;
1447
1448 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1449 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1450}
1451
1452#[gpui::test]
1453async fn test_refusal(cx: &mut TestAppContext) {
1454 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1455 let fake_model = model.as_fake();
1456
1457 let events = thread
1458 .update(cx, |thread, cx| {
1459 thread.send(UserMessageId::new(), ["Hello"], cx)
1460 })
1461 .unwrap();
1462 cx.run_until_parked();
1463 thread.read_with(cx, |thread, _| {
1464 assert_eq!(
1465 thread.to_markdown(),
1466 indoc! {"
1467 ## User
1468
1469 Hello
1470 "}
1471 );
1472 });
1473
1474 fake_model.send_last_completion_stream_text_chunk("Hey!");
1475 cx.run_until_parked();
1476 thread.read_with(cx, |thread, _| {
1477 assert_eq!(
1478 thread.to_markdown(),
1479 indoc! {"
1480 ## User
1481
1482 Hello
1483
1484 ## Assistant
1485
1486 Hey!
1487 "}
1488 );
1489 });
1490
1491 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1492 fake_model
1493 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1494 let events = events.collect::<Vec<_>>().await;
1495 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1496 thread.read_with(cx, |thread, _| {
1497 assert_eq!(thread.to_markdown(), "");
1498 });
1499}
1500
/// Truncating at the very first user message should empty the thread and
/// reset the token usage, and the thread must remain usable for new sends
/// afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has reported usage yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update for the first turn.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message wipes the transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1616
/// Truncating at the second user message should restore the thread to its
/// state after the first exchange — including the earlier token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message, reply, and a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused before and after truncation: the thread must
    // look exactly like it did after the first exchange.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, keeping its id for the truncation call below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: everything from it onward is dropped,
    // and the first exchange's usage becomes current again.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1725
1726#[gpui::test]
1727async fn test_title_generation(cx: &mut TestAppContext) {
1728 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1729 let fake_model = model.as_fake();
1730
1731 let summary_model = Arc::new(FakeLanguageModel::default());
1732 thread.update(cx, |thread, cx| {
1733 thread.set_summarization_model(Some(summary_model.clone()), cx)
1734 });
1735
1736 let send = thread
1737 .update(cx, |thread, cx| {
1738 thread.send(UserMessageId::new(), ["Hello"], cx)
1739 })
1740 .unwrap();
1741 cx.run_until_parked();
1742
1743 fake_model.send_last_completion_stream_text_chunk("Hey!");
1744 fake_model.end_last_completion_stream();
1745 cx.run_until_parked();
1746 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1747
1748 // Ensure the summary model has been invoked to generate a title.
1749 summary_model.send_last_completion_stream_text_chunk("Hello ");
1750 summary_model.send_last_completion_stream_text_chunk("world\nG");
1751 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1752 summary_model.end_last_completion_stream();
1753 send.collect::<Vec<_>>().await;
1754 cx.run_until_parked();
1755 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1756
1757 // Send another message, ensuring no title is generated this time.
1758 let send = thread
1759 .update(cx, |thread, cx| {
1760 thread.send(UserMessageId::new(), ["Hello again"], cx)
1761 })
1762 .unwrap();
1763 cx.run_until_parked();
1764 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1765 fake_model.end_last_completion_stream();
1766 cx.run_until_parked();
1767 assert_eq!(summary_model.pending_completions(), Vec::new());
1768 send.collect::<Vec<_>>().await;
1769 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1770}
1771
/// When a completion request is rebuilt mid-turn, tool calls that are still
/// awaiting permission must be omitted, while completed tool calls (and
/// their results) are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call blocks on user permission and stays pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system prompt; compare everything after it. Only the
    // echo tool use/result appears — the permission-gated call is absent.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1846
/// End-to-end exercise of `NativeAgentConnection` through the ACP surface:
/// thread creation, model listing/selection, prompting, cancellation, and
/// session teardown when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a reply from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with a "Session not found" error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1981
/// Verifies the ACP event sequence produced by a streamed tool call:
/// initial `ToolCall`, then field updates for streamed input, the
/// in-progress status, streamed content, and finally completion with output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) The initial tool-call event, still pending with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2) Raw input finalized once streaming of the input completes.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3) The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5) Completion with the final raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2088
2089#[gpui::test]
2090async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2091 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2092 let fake_model = model.as_fake();
2093
2094 let mut events = thread
2095 .update(cx, |thread, cx| {
2096 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2097 thread.send(UserMessageId::new(), ["Hello!"], cx)
2098 })
2099 .unwrap();
2100 cx.run_until_parked();
2101
2102 fake_model.send_last_completion_stream_text_chunk("Hey!");
2103 fake_model.end_last_completion_stream();
2104
2105 let mut retry_events = Vec::new();
2106 while let Some(Ok(event)) = events.next().await {
2107 match event {
2108 ThreadEvent::Retry(retry_status) => {
2109 retry_events.push(retry_status);
2110 }
2111 ThreadEvent::Stop(..) => break,
2112 _ => {}
2113 }
2114 }
2115
2116 assert_eq!(retry_events.len(), 0);
2117 thread.read_with(cx, |thread, _cx| {
2118 assert_eq!(
2119 thread.to_markdown(),
2120 indoc! {"
2121 ## User
2122
2123 Hello!
2124
2125 ## Assistant
2126
2127 Hey!
2128 "}
2129 )
2130 });
2131}
2132
/// A `ServerOverloaded` error mid-stream should schedule a retry: after the
/// provider's `retry_after` delay elapses, the turn resumes, a single
/// `Retry` event (attempt 1) is emitted, and the transcript shows the
/// resumed assistant message.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2197
/// If the stream errors after a tool call was issued, the tool still runs to
/// completion before the retry, and the retried request includes both the
/// tool use and its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Issue a tool call, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the backoff, the retried request must carry the tool use AND the
    // completed tool result (messages[0] is the system prompt).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2273
2274#[gpui::test]
2275async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2276 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2277 let fake_model = model.as_fake();
2278
2279 let mut events = thread
2280 .update(cx, |thread, cx| {
2281 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2282 thread.send(UserMessageId::new(), ["Hello!"], cx)
2283 })
2284 .unwrap();
2285 cx.run_until_parked();
2286
2287 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2288 fake_model.send_last_completion_stream_error(
2289 LanguageModelCompletionError::ServerOverloaded {
2290 provider: LanguageModelProviderName::new("Anthropic"),
2291 retry_after: Some(Duration::from_secs(3)),
2292 },
2293 );
2294 fake_model.end_last_completion_stream();
2295 cx.executor().advance_clock(Duration::from_secs(3));
2296 cx.run_until_parked();
2297 }
2298
2299 let mut errors = Vec::new();
2300 let mut retry_events = Vec::new();
2301 while let Some(event) = events.next().await {
2302 match event {
2303 Ok(ThreadEvent::Retry(retry_status)) => {
2304 retry_events.push(retry_status);
2305 }
2306 Ok(ThreadEvent::Stop(..)) => break,
2307 Err(error) => errors.push(error),
2308 _ => {}
2309 }
2310 }
2311
2312 assert_eq!(
2313 retry_events.len(),
2314 crate::thread::MAX_RETRY_ATTEMPTS as usize
2315 );
2316 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2317 assert_eq!(retry_events[i].attempt, i + 1);
2318 }
2319 assert_eq!(errors.len(), 1);
2320 let error = errors[0]
2321 .downcast_ref::<LanguageModelCompletionError>()
2322 .unwrap();
2323 assert!(matches!(
2324 error,
2325 LanguageModelCompletionError::ServerOverloaded { .. }
2326 ));
2327}
2328
2329/// Filters out the stop events for asserting against in tests
2330fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2331 result_events
2332 .into_iter()
2333 .filter_map(|event| match event.unwrap() {
2334 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2335 _ => None,
2336 })
2337 .collect()
2338}
2339
/// Bundle of entities produced by [`setup`] for driving a [`Thread`] in tests.
struct ThreadTest {
    /// The model the thread talks to; a `FakeLanguageModel` when built with
    /// `TestModel::Fake`.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    /// Shared project context; tests mutate it in place to simulate rule changes.
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2347
/// Selects which language model a test runs against.
enum TestModel {
    /// Claude Sonnet 4 via the production client (see the `Sonnet4` branch in `setup`).
    Sonnet4,
    /// In-process `FakeLanguageModel` with test-scripted responses.
    Fake,
}
2352
2353impl TestModel {
2354 fn id(&self) -> LanguageModelId {
2355 match self {
2356 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2357 TestModel::Fake => unreachable!(),
2358 }
2359 }
2360}
2361
/// Builds a fully wired [`ThreadTest`]: a fake filesystem seeded with a settings
/// file that enables every test tool, a test project, and the requested model
/// (either the in-process fake or an authenticated production model).
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed the fake filesystem with a settings file whose "test-profile"
    // enables all of the tools the tests register.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real-model variant needs the HTTP client, user store, and
        // language-model providers; the fake model is constructed directly below.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later writes to the
        // settings file (tests mutate it to exercise profile changes).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is ready immediately; a real model is looked
    // up in the registry and its provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // Wire the thread to a shared project context and the project's context
    // server registry, mirroring production construction.
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2463
/// Registered via `ctor` to run at startup; enables `env_logger` output only
/// when the caller opted in by setting the `RUST_LOG` environment variable.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2471
2472fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2473 let fs = fs.clone();
2474 cx.spawn({
2475 async move |cx| {
2476 let mut new_settings_content_rx = settings::watch_config_file(
2477 cx.background_executor(),
2478 fs,
2479 paths::settings_file().clone(),
2480 );
2481
2482 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2483 cx.update(|cx| {
2484 SettingsStore::update_global(cx, |settings, cx| {
2485 settings.set_user_settings(&new_settings_content, cx)
2486 })
2487 })
2488 .ok();
2489 }
2490 }
2491 })
2492 .detach();
2493}
2494
2495fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2496 completion
2497 .tools
2498 .iter()
2499 .map(|tool| tool.name.clone())
2500 .collect()
2501}
2502
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`. Returns a receiver yielding each incoming tool call together with a
/// oneshot sender the test uses to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it; the
    // command is a dummy since the fake transport below bypasses spawning.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport answering the three requests the store issues:
    // Initialize (capabilities handshake), ListTools, and CallTool (forwarded
    // to the test through the channel so it can script each response).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                // Hand the call to the test and block this request until the
                // test replies via the oneshot sender.
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool listing complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2581
2582/// Tests that rules toggled after thread creation are applied to the first message.
2583///
2584/// This test verifies the fix for https://github.com/zed-industries/zed/issues/39057
2585/// where rules toggled in the Rules Library after opening a new agent thread were not
2586/// being applied to the first message sent in that thread.
2587///
2588/// The test simulates:
2589/// 1. Creating a thread with an initial rule in the project context
2590/// 2. Updating the project context entity with new rules (simulating what
2591/// NativeAgent.maintain_project_context does when rules are toggled)
2592/// 3. Sending a message through the thread
2593/// 4. Verifying that the newly toggled rules appear in the system prompt
2594///
2595/// The fix ensures that threads see updated rules by updating the project_context
2596/// entity in place rather than creating a new entity that threads wouldn't reference.
2597#[gpui::test]
2598async fn test_rules_toggled_after_thread_creation_are_applied(cx: &mut TestAppContext) {
2599 let ThreadTest {
2600 model,
2601 thread,
2602 project_context,
2603 ..
2604 } = setup(cx, TestModel::Fake).await;
2605 let fake_model = model.as_fake();
2606
2607 let initial_rule_content = "Initial rule before thread creation.";
2608 project_context.update(cx, |context, _cx| {
2609 context.user_rules.push(UserRulesContext {
2610 uuid: UserPromptId::new(),
2611 title: Some("Initial Rule".to_string()),
2612 contents: initial_rule_content.to_string(),
2613 });
2614 context.has_user_rules = true;
2615 });
2616
2617 let rule_id = UserPromptId::new();
2618 let new_rule_content = "Always respond in uppercase.";
2619
2620 project_context.update(cx, |context, _cx| {
2621 context.user_rules.clear();
2622 context.user_rules.push(UserRulesContext {
2623 uuid: rule_id,
2624 title: Some("New Rule".to_string()),
2625 contents: new_rule_content.to_string(),
2626 });
2627 context.has_user_rules = true;
2628 });
2629
2630 thread
2631 .update(cx, |thread, cx| {
2632 thread.send(UserMessageId::new(), ["test message"], cx)
2633 })
2634 .unwrap();
2635
2636 cx.run_until_parked();
2637
2638 let mut pending_completions = fake_model.pending_completions();
2639 assert_eq!(pending_completions.len(), 1);
2640
2641 let pending_completion = pending_completions.pop().unwrap();
2642 let system_message = &pending_completion.messages[0];
2643 let system_prompt = system_message.content[0].to_str().unwrap();
2644
2645 assert!(
2646 system_prompt.contains(new_rule_content),
2647 "System prompt should contain the rule content that was toggled after thread creation"
2648 );
2649}